bartz 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bartz/BART.py +464 -254
- bartz/__init__.py +2 -2
- bartz/_version.py +1 -1
- bartz/debug.py +1259 -79
- bartz/grove.py +139 -93
- bartz/jaxext/__init__.py +213 -0
- bartz/jaxext/_autobatch.py +238 -0
- bartz/jaxext/scipy/__init__.py +25 -0
- bartz/jaxext/scipy/special.py +240 -0
- bartz/jaxext/scipy/stats.py +36 -0
- bartz/mcmcloop.py +468 -311
- bartz/mcmcstep.py +734 -453
- bartz/prepcovars.py +139 -43
- {bartz-0.6.0.dist-info → bartz-0.7.0.dist-info}/METADATA +2 -3
- bartz-0.7.0.dist-info/RECORD +17 -0
- {bartz-0.6.0.dist-info → bartz-0.7.0.dist-info}/WHEEL +1 -1
- bartz/jaxext.py +0 -423
- bartz-0.6.0.dist-info/RECORD +0 -13
bartz/grove.py
CHANGED
|
@@ -22,93 +22,113 @@
|
|
|
22
22
|
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
23
23
|
# SOFTWARE.
|
|
24
24
|
|
|
25
|
-
"""
|
|
25
|
+
"""Functions to create and manipulate binary decision trees."""
|
|
26
26
|
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
A decision tree is represented by tree arrays: 'leaf', 'var', and 'split'.
|
|
32
|
-
|
|
33
|
-
The 'leaf' array contains the values in the leaves.
|
|
27
|
+
import math
|
|
28
|
+
from functools import partial
|
|
29
|
+
from typing import Protocol
|
|
34
30
|
|
|
35
|
-
|
|
31
|
+
import jax
|
|
32
|
+
from jax import jit, lax
|
|
33
|
+
from jax import numpy as jnp
|
|
34
|
+
from jaxtyping import Array, Bool, DTypeLike, Float32, Int32, Real, Shaped, UInt
|
|
36
35
|
|
|
37
|
-
|
|
36
|
+
from bartz.jaxext import minimal_unsigned_dtype, vmap_nodoc
|
|
38
37
|
|
|
39
|
-
Since the nodes at the bottom can only be leaves and not decision nodes, the 'var' and 'split' arrays have half the length of the 'leaf' array.
|
|
40
38
|
|
|
41
|
-
|
|
39
|
+
class TreeHeaps(Protocol):
    """Structural type for dataclasses that hold binary decision trees as heaps.

    Every tree is laid out in arrays following the heap convention: the root
    lives at index 1, and the node at index :math:`i` has its left child at
    :math:`2i` and its right child at :math:`2i + 1`. Index 0 is never a node.

    Parameters
    ----------
    leaf_tree
        Leaf values. May be dirty: entries of unused nodes can hold arbitrary
        values.
    var_tree
        Axis (variable index) each decision node splits on. May be dirty,
        except for the always-unused entry at index 0, which must be 0.
    split_tree
        Decision thresholds. Intervals are open on the right, so a point is
        sent to the left child iff ``x < split``. A split of 0 marks a leaf;
        unused nodes also carry 0. Must not be dirty.

    Notes
    -----
    Nodes on the bottom level can only be leaves, never decision nodes, so
    `var_tree` and `split_tree` have half the length of `leaf_tree`.
    """

    leaf_tree: Float32[Array, '* 2**d']
    var_tree: UInt[Array, '* 2**(d-1)']
    split_tree: UInt[Array, '* 2**(d-1)']
|
|
51
70
|
|
|
52
71
|
|
|
53
|
-
def make_tree(depth: int, dtype: DTypeLike) -> Shaped[Array, ' 2**{depth}']:
    """
    Allocate a zero-filled heap array for a binary tree.

    Parameters
    ----------
    depth
        Maximum depth of the tree. A depth of 1 means the tree can hold only a
        root node.
    dtype
        Data type of the returned array.

    Returns
    -------
    A 1D array of ``2**depth`` zeros.
    """
    num_slots = 2**depth
    return jnp.zeros(num_slots, dtype)
|
|
71
89
|
|
|
72
90
|
|
|
73
|
-
def tree_depth(tree: Shaped[Array, '* 2**d']) -> int:
    """
    Infer the maximum depth of a tree from the length of its heap array.

    Parameters
    ----------
    tree
        A tree array as produced by `make_tree`. For arrays with more than one
        dimension, the heap is read along the last axis.

    Returns
    -------
    The maximum depth: the base-2 logarithm of the last-axis length, rounded
    to the nearest integer.
    """
    heap_length = tree.shape[-1]
    return round(math.log2(heap_length))
|
|
89
106
|
|
|
90
107
|
|
|
91
|
-
def traverse_tree(
|
|
108
|
+
def traverse_tree(
|
|
109
|
+
x: Real[Array, ' p'],
|
|
110
|
+
var_tree: UInt[Array, ' 2**(d-1)'],
|
|
111
|
+
split_tree: UInt[Array, ' 2**(d-1)'],
|
|
112
|
+
) -> Int32[Array, '']:
|
|
92
113
|
"""
|
|
93
114
|
Find the leaf where a point falls into.
|
|
94
115
|
|
|
95
116
|
Parameters
|
|
96
117
|
----------
|
|
97
|
-
x
|
|
118
|
+
x
|
|
98
119
|
The coordinates to evaluate the tree at.
|
|
99
|
-
var_tree
|
|
120
|
+
var_tree
|
|
100
121
|
The decision axes of the tree.
|
|
101
|
-
split_tree
|
|
122
|
+
split_tree
|
|
102
123
|
The decision boundaries of the tree.
|
|
103
124
|
|
|
104
125
|
Returns
|
|
105
126
|
-------
|
|
106
|
-
index
|
|
107
|
-
The index of the leaf.
|
|
127
|
+
The index of the leaf.
|
|
108
128
|
"""
|
|
109
129
|
carry = (
|
|
110
130
|
jnp.zeros((), bool),
|
|
111
|
-
jnp.ones((),
|
|
131
|
+
jnp.ones((), minimal_unsigned_dtype(2 * var_tree.size - 1)),
|
|
112
132
|
)
|
|
113
133
|
|
|
114
134
|
def loop(carry, _):
|
|
@@ -128,111 +148,107 @@ def traverse_tree(x, var_tree, split_tree):
|
|
|
128
148
|
return index
|
|
129
149
|
|
|
130
150
|
|
|
131
|
-
@partial(vmap_nodoc, in_axes=(None, 0, 0))
@partial(vmap_nodoc, in_axes=(1, None, None))
def traverse_forest(
    X: Real[Array, 'p n'],
    var_trees: UInt[Array, 'm 2**(d-1)'],
    split_trees: UInt[Array, 'm 2**(d-1)'],
) -> Int32[Array, 'm n']:
    """
    Locate the leaf reached by every point in every tree.

    Parameters
    ----------
    X
        Points to route through the trees, one point per column.
    var_trees
        Decision axes, one tree per row.
    split_trees
        Decision thresholds, one tree per row.

    Returns
    -------
    Leaf indices, one row per tree and one column per point.
    """
    # The inner vmap maps over the columns of X and the outer vmap over the
    # rows of the tree arrays, so the body only ever sees one point and one
    # tree at a time.
    return traverse_tree(X, var_trees, split_trees)
|
|
152
175
|
|
|
153
176
|
|
|
154
|
-
def evaluate_forest(
    X: UInt[Array, 'p n'], trees: TreeHeaps, *, sum_trees: bool = True
) -> Float32[Array, ' n'] | Float32[Array, 'm n']:
    """
    Evaluate an ensemble of trees at an array of points.

    Parameters
    ----------
    X
        The coordinates to evaluate the trees at, one point per column.
    trees
        The tree heaps, with batch shape (m,).
    sum_trees
        Whether to sum the values across trees.

    Returns
    -------
    The (sum of the) values of the trees at the points in `X`: shape (n,) if
    `sum_trees` is True, else (m, n).
    """
    # leaf index reached by each point in each tree, shape (m, n)
    indices = traverse_forest(X, trees.var_tree, trees.split_tree)
    ntree, _ = trees.leaf_tree.shape
    tree_index = jnp.arange(ntree, dtype=minimal_unsigned_dtype(ntree - 1))
    # pair every row of leaf indices with the leaf array of its own tree
    leaves = trees.leaf_tree[tree_index[:, None], indices]
    if sum_trees:
        # this sum suggests to swap the vmaps, but I think it's better for X
        # copying to keep it that way
        return jnp.sum(leaves, axis=0, dtype=jnp.float32)
    else:
        return leaves
|
|
189
206
|
|
|
190
207
|
|
|
191
|
-
def is_actual_leaf(
    split_tree: UInt[Array, ' 2**(d-1)'], *, add_bottom_level: bool = False
) -> Bool[Array, ' 2**(d-1)'] | Bool[Array, ' 2**d']:
    """
    Flag the nodes of a tree that are leaves.

    A node counts as a leaf when its split is 0 and its parent is a decision
    node (nonzero split). The root, at index 1, only needs a zero split.

    Parameters
    ----------
    split_tree
        The split thresholds of the tree.
    add_bottom_level
        If True, also classify the nodes of the bottom level. These have no
        entry in `split_tree` and can only be leaves.

    Returns
    -------
    Boolean mask over the nodes, twice as long when `add_bottom_level` is True.
    """
    leaf_candidate = split_tree == 0
    num_nodes = split_tree.size
    if add_bottom_level:
        # bottom-level nodes cannot split, so all of them are leaf candidates
        leaf_candidate = jnp.concatenate(
            [leaf_candidate, jnp.ones_like(leaf_candidate)]
        )
        num_nodes = 2 * num_nodes
    node = jnp.arange(num_nodes, dtype=minimal_unsigned_dtype(num_nodes - 1))
    parent_splits = split_tree[node >> 1].astype(bool)
    # the root has no parent: force it to pass the parent check
    parent_splits = parent_splits.at[1].set(True)
    return leaf_candidate & parent_splits
|
|
218
235
|
|
|
219
236
|
|
|
220
|
-
def is_leaves_parent(split_tree):
|
|
237
|
+
def is_leaves_parent(split_tree: UInt[Array, ' 2**(d-1)']) -> Bool[Array, ' 2**(d-1)']:
|
|
221
238
|
"""
|
|
222
239
|
Return a mask indicating the nodes with leaf (and only leaf) children.
|
|
223
240
|
|
|
224
241
|
Parameters
|
|
225
242
|
----------
|
|
226
|
-
split_tree
|
|
243
|
+
split_tree
|
|
227
244
|
The decision boundaries of the tree.
|
|
228
245
|
|
|
229
246
|
Returns
|
|
230
247
|
-------
|
|
231
|
-
|
|
232
|
-
The mask indicating which nodes have leaf children.
|
|
248
|
+
The mask indicating which nodes have leaf children.
|
|
233
249
|
"""
|
|
234
250
|
index = jnp.arange(
|
|
235
|
-
split_tree.size, dtype=
|
|
251
|
+
split_tree.size, dtype=minimal_unsigned_dtype(2 * split_tree.size - 1)
|
|
236
252
|
)
|
|
237
253
|
left_index = index << 1 # left child
|
|
238
254
|
right_index = left_index + 1 # right child
|
|
@@ -243,21 +259,24 @@ def is_leaves_parent(split_tree):
|
|
|
243
259
|
# the 0-th item has split == 0, so it's not counted
|
|
244
260
|
|
|
245
261
|
|
|
246
|
-
def tree_depths(tree_length):
|
|
262
|
+
def tree_depths(tree_length: int) -> Int32[Array, ' {tree_length}']:
|
|
247
263
|
"""
|
|
248
264
|
Return the depth of each node in a binary tree.
|
|
249
265
|
|
|
250
266
|
Parameters
|
|
251
267
|
----------
|
|
252
|
-
tree_length
|
|
268
|
+
tree_length
|
|
253
269
|
The length of the tree array, i.e., 2 ** d.
|
|
254
270
|
|
|
255
271
|
Returns
|
|
256
272
|
-------
|
|
257
|
-
depth
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
273
|
+
The depth of each node.
|
|
274
|
+
|
|
275
|
+
Notes
|
|
276
|
+
-----
|
|
277
|
+
The root node (index 1) has depth 0. The depth is the position of the most
|
|
278
|
+
significant non-zero bit in the index. The first element (the unused node)
|
|
279
|
+
is marked as depth 0.
|
|
261
280
|
"""
|
|
262
281
|
depths = []
|
|
263
282
|
depth = 0
|
|
@@ -266,22 +285,21 @@ def tree_depths(tree_length):
|
|
|
266
285
|
depth += 1
|
|
267
286
|
depths.append(depth - 1)
|
|
268
287
|
depths[0] = 0
|
|
269
|
-
return jnp.array(depths,
|
|
288
|
+
return jnp.array(depths, minimal_unsigned_dtype(max(depths)))
|
|
270
289
|
|
|
271
290
|
|
|
272
|
-
def is_used(split_tree):
|
|
291
|
+
def is_used(split_tree: UInt[Array, ' 2**(d-1)']) -> Bool[Array, ' 2**d']:
|
|
273
292
|
"""
|
|
274
293
|
Return a mask indicating the used nodes in a tree.
|
|
275
294
|
|
|
276
295
|
Parameters
|
|
277
296
|
----------
|
|
278
|
-
split_tree
|
|
297
|
+
split_tree
|
|
279
298
|
The decision boundaries of the tree.
|
|
280
299
|
|
|
281
300
|
Returns
|
|
282
301
|
-------
|
|
283
|
-
|
|
284
|
-
A mask indicating which nodes are actually used.
|
|
302
|
+
A mask indicating which nodes are actually used.
|
|
285
303
|
"""
|
|
286
304
|
internal_node = split_tree.astype(bool)
|
|
287
305
|
internal_node = jnp.concatenate([internal_node, jnp.zeros_like(internal_node)])
|
|
@@ -289,22 +307,50 @@ def is_used(split_tree):
|
|
|
289
307
|
return internal_node | actual_leaf
|
|
290
308
|
|
|
291
309
|
|
|
292
|
-
|
|
310
|
+
@jit
def forest_fill(split_tree: UInt[Array, 'num_trees 2**(d-1)']) -> Float32[Array, '']:
    """
    Measure how full the heap arrays of a forest are.

    Parameters
    ----------
    split_tree
        The decision boundaries of the trees, one tree per row.

    Returns
    -------
    Fraction of the available node slots, over all trees, that hold an actual
    node (decision node or leaf).
    """
    num_trees, _ = split_tree.shape
    used_mask = jax.vmap(is_used)(split_tree)
    # slot 0 of every heap is never a node, so it does not count as capacity
    capacity = used_mask.size - num_trees
    return jnp.count_nonzero(used_mask) / capacity
|
|
328
|
+
|
|
329
|
+
|
|
330
|
+
def var_histogram(
    p: int, var_tree: UInt[Array, '* 2**(d-1)'], split_tree: UInt[Array, '* 2**(d-1)']
) -> Int32[Array, ' {p}']:
    """
    Tally how often each variable is used by the decision nodes of a tree.

    Parameters
    ----------
    p
        The number of variables; `var_tree` holds values up to ``p - 1``.
    var_tree
        The decision axes of the tree.
    split_tree
        The decision boundaries of the tree.

    Returns
    -------
    A length-`p` array counting the decision nodes that split on each variable.

    Notes
    -----
    Arrays with leading axes (several trees) give counts accumulated over all
    the trees.
    """
    counts = jnp.zeros(p, int)
    # nodes with split 0 contribute 0, so only decision nodes are counted
    return counts.at[var_tree].add(split_tree.astype(bool))
|
bartz/jaxext/__init__.py
ADDED
|
@@ -0,0 +1,213 @@
|
|
|
1
|
+
# bartz/src/bartz/jaxext/__init__.py
|
|
2
|
+
#
|
|
3
|
+
# Copyright (c) 2024-2025, Giacomo Petrillo
|
|
4
|
+
#
|
|
5
|
+
# This file is part of bartz.
|
|
6
|
+
#
|
|
7
|
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
8
|
+
# of this software and associated documentation files (the "Software"), to deal
|
|
9
|
+
# in the Software without restriction, including without limitation the rights
|
|
10
|
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
11
|
+
# copies of the Software, and to permit persons to whom the Software is
|
|
12
|
+
# furnished to do so, subject to the following conditions:
|
|
13
|
+
#
|
|
14
|
+
# The above copyright notice and this permission notice shall be included in all
|
|
15
|
+
# copies or substantial portions of the Software.
|
|
16
|
+
#
|
|
17
|
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
18
|
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
19
|
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
20
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
21
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
22
|
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
23
|
+
# SOFTWARE.
|
|
24
|
+
|
|
25
|
+
"""Additions to jax."""
|
|
26
|
+
|
|
27
|
+
import functools
|
|
28
|
+
import math
|
|
29
|
+
from collections.abc import Sequence
|
|
30
|
+
|
|
31
|
+
import jax
|
|
32
|
+
from jax import numpy as jnp
|
|
33
|
+
from jax import random
|
|
34
|
+
from jax.lax import scan
|
|
35
|
+
from jax.scipy.special import ndtr
|
|
36
|
+
from jaxtyping import Array, Bool, Float32, Key, Scalar, Shaped
|
|
37
|
+
|
|
38
|
+
from bartz.jaxext._autobatch import autobatch # noqa: F401
|
|
39
|
+
from bartz.jaxext.scipy.special import ndtri
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def vmap_nodoc(fun, *args, **kw):
    """
    Vectorize `fun` with `jax.vmap`, keeping its docstring untouched.

    `jax.vmap` rewrites the docstring of the function it wraps. Use this
    variant when the docstring already describes the batched arguments.
    """
    original_doc = fun.__doc__
    vectorized = jax.vmap(fun, *args, **kw)
    vectorized.__doc__ = original_doc
    return vectorized
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def minimal_unsigned_dtype(value):
    """
    Pick the narrowest unsigned integer dtype able to hold `value`.

    Widths of 8, 16 and 32 bits are tried in order. Anything that does not fit
    in 32 bits gets ``uint64``. Negative values are not rejected and yield
    ``uint8``.
    """
    for bits, dtype in ((8, jnp.uint8), (16, jnp.uint16), (32, jnp.uint32)):
        if value < 1 << bits:
            return dtype
    return jnp.uint64
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
@functools.partial(jax.jit, static_argnums=(1,))
def unique(
    x: Shaped[Array, ' _'], size: int, fill_value: Scalar
) -> tuple[Shaped[Array, ' {size}'], int]:
    """
    Restricted version of `jax.numpy.unique` that uses less memory.

    Parameters
    ----------
    x
        The input array.
    size
        The length of the output.
    fill_value
        The value to fill the output with if `size` is greater than the number
        of unique values in `x`.

    Returns
    -------
    out : Shaped[Array, '{size}']
        The first `size` unique values in `x`, sorted, and right-padded with
        `fill_value`.
    actual_length : int
        The number of used values in `out`, i.e., the minimum between `size`
        and the number of unique values in `x`.

    Notes
    -----
    The whole sorted input is scanned, so every unique value is found even
    when `x` contains duplicates and ``size < x.size``. Unique values beyond
    the first `size` are discarded, like `jax.numpy.unique` does with its
    `size` argument.
    """
    if x.size == 0:
        return jnp.full(size, fill_value, x.dtype), 0
    if size == 0:
        return jnp.empty(0, x.dtype), 0
    x = jnp.sort(x)

    def loop(carry, x):
        i_out, last, out = carry
        # move to the next output slot only when a new distinct value starts
        i_out = jnp.where(x == last, i_out, i_out + 1)
        # once i_out >= size the output is full: drop the write explicitly
        # instead of relying on the default out-of-bounds scatter behavior
        out = out.at[i_out].set(x, mode='drop')
        return (i_out, x, out), None

    carry = 0, x[0], jnp.full(size, fill_value, x.dtype)
    (last_index, _, out), _ = scan(loop, carry, x)
    return out, jnp.minimum(last_index + 1, size)
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
class split:
    """
    Split a key into `num` keys.

    Parameters
    ----------
    key
        The key to split.
    num
        The number of keys to split into.

    Notes
    -----
    Keys are counted along the leading axis of the split array, so both typed
    keys (`jax.random.key`) and legacy raw ``uint32`` keys (`jax.random.PRNGKey`,
    which carry a trailing key-data axis) are handled.
    """

    def __init__(self, key: Key[Array, ''], num: int = 2):
        self._keys = random.split(key, num)

    def __len__(self):
        # number of keys left; `.size` would also count the trailing key-data
        # axis of legacy raw keys
        return self._keys.shape[0]

    def pop(self, shape: int | tuple[int, ...] | None = None) -> Key[Array, '*']:
        """
        Pop one or more keys from the list.

        Parameters
        ----------
        shape
            The shape of the keys to pop. If `None`, a single key is popped.
            If an integer, that many keys are popped. If a tuple, the keys are
            reshaped to that shape.

        Returns
        -------
        The popped keys as a jax array with the requested shape (followed by
        the key-data axis for legacy raw keys).

        Raises
        ------
        IndexError
            If `shape` is larger than the number of keys left in the list.

        Notes
        -----
        The keys are popped from the beginning of the list, so for example
        ``list(keys.pop(2))`` is equivalent to ``[keys.pop(), keys.pop()]``.
        """
        if shape is None:
            shape = ()
        elif not isinstance(shape, tuple):
            shape = (shape,)
        size_to_pop = math.prod(shape)
        if size_to_pop > len(self):
            msg = f'Cannot pop {size_to_pop} keys from {len(self)} keys'
            raise IndexError(msg)
        popped_keys = self._keys[:size_to_pop]
        self._keys = self._keys[size_to_pop:]
        # keep the trailing key-data axis of legacy raw keys, if any
        return popped_keys.reshape(shape + popped_keys.shape[1:])
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
def truncated_normal_onesided(
    key: Key[Array, ''],
    shape: Sequence[int],
    upper: Bool[Array, '*'],
    bound: Float32[Array, '*'],
) -> Float32[Array, '*']:
    """
    Draw samples from a standard normal truncated on one side.

    Parameters
    ----------
    key
        JAX random key.
    shape
        Requested output shape. It is broadcast together with the shapes of
        `upper` and `bound`.
    upper
        Where True, sample from (-∞, bound]; where False, from [bound, ∞).
    bound
        The truncation point.

    Returns
    -------
    Samples from the truncated standard normal, with the broadcast shape.

    Notes
    -----
    Inverse-CDF sampling. Where ``bound > 0``, the problem is mirrored: the
    sample is drawn from the reflected interval, whose bound is ``-bound``, and
    then negated. Once mirrored, the kept interval is a lower tail iff
    ``upper != (bound > 0)``. A lower tail is sampled from the uniform range
    (0, Φ(b)]; an upper tail from [Φ(b), 1).
    """
    out_shape = jnp.broadcast_shapes(shape, upper.shape, bound.shape)
    flip = bound > 0
    cdf_bound = ndtr(bound)
    cdf_neg_bound = ndtr(-bound)
    # probability mass of the kept interval, and the CDF value where it starts
    # when it is an upper tail
    width = jnp.where(upper, cdf_bound, cdf_neg_bound)
    offset = jnp.where(upper, cdf_neg_bound, cdf_bound)
    u = random.uniform(key, out_shape)
    lower_tail_u = width * (1 - u)  # ~ uniform in (0, ndtr(±bound)]
    upper_tail_u = offset + width * u  # ~ uniform in [ndtr(∓bound), 1)
    use_lower_tail = upper ^ flip
    z = ndtri(jnp.where(use_lower_tail, lower_tail_u, upper_tail_u))
    return jnp.where(flip, -z, z)
|