bartz 0.4.1__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bartz/.DS_Store +0 -0
- bartz/BART.py +266 -113
- bartz/__init__.py +4 -12
- bartz/_version.py +1 -1
- bartz/debug.py +42 -16
- bartz/grove.py +62 -12
- bartz/jaxext.py +111 -37
- bartz/mcmcloop.py +419 -105
- bartz/mcmcstep.py +1528 -760
- bartz/prepcovars.py +25 -10
- {bartz-0.4.1.dist-info → bartz-0.6.0.dist-info}/METADATA +14 -16
- bartz-0.6.0.dist-info/RECORD +13 -0
- bartz-0.6.0.dist-info/WHEEL +4 -0
- bartz-0.4.1.dist-info/LICENSE +0 -21
- bartz-0.4.1.dist-info/RECORD +0 -13
- bartz-0.4.1.dist-info/WHEEL +0 -4
bartz/debug.py
CHANGED
|
@@ -1,21 +1,19 @@
|
|
|
1
1
|
import functools
|
|
2
2
|
|
|
3
3
|
import jax
|
|
4
|
-
from jax import numpy as jnp
|
|
5
4
|
from jax import lax
|
|
5
|
+
from jax import numpy as jnp
|
|
6
6
|
|
|
7
|
-
from . import grove
|
|
8
|
-
from . import mcmcstep
|
|
9
|
-
from . import jaxext
|
|
7
|
+
from . import grove, jaxext
|
|
10
8
|
|
|
11
|
-
def print_tree(leaf_tree, var_tree, split_tree, print_all=False):
|
|
12
9
|
|
|
10
|
+
def print_tree(leaf_tree, var_tree, split_tree, print_all=False):
|
|
13
11
|
tee = '├──'
|
|
14
12
|
corner = '└──'
|
|
15
13
|
join = '│ '
|
|
16
14
|
space = ' '
|
|
17
15
|
down = '┐'
|
|
18
|
-
bottom = '╢'
|
|
16
|
+
bottom = '╢' # '┨' #
|
|
19
17
|
|
|
20
18
|
def traverse_tree(index, depth, indent, first_indent, next_indent, unused):
|
|
21
19
|
if index >= len(leaf_tree):
|
|
@@ -58,7 +56,7 @@ def print_tree(leaf_tree, var_tree, split_tree, print_all=False):
|
|
|
58
56
|
|
|
59
57
|
indent += next_indent
|
|
60
58
|
unused = unused or is_leaf
|
|
61
|
-
|
|
59
|
+
|
|
62
60
|
if unused and not print_all:
|
|
63
61
|
return
|
|
64
62
|
|
|
@@ -67,58 +65,80 @@ def print_tree(leaf_tree, var_tree, split_tree, print_all=False):
|
|
|
67
65
|
|
|
68
66
|
traverse_tree(1, 0, '', '', '', False)
|
|
69
67
|
|
|
68
|
+
|
|
70
69
|
def tree_actual_depth(split_tree):
|
|
71
70
|
is_leaf = grove.is_actual_leaf(split_tree, add_bottom_level=True)
|
|
72
71
|
depth = grove.tree_depths(is_leaf.size)
|
|
73
72
|
depth = jnp.where(is_leaf, depth, 0)
|
|
74
73
|
return jnp.max(depth)
|
|
75
74
|
|
|
75
|
+
|
|
76
76
|
def forest_depth_distr(split_trees):
|
|
77
77
|
depth = grove.tree_depth(split_trees) + 1
|
|
78
78
|
depths = jax.vmap(tree_actual_depth)(split_trees)
|
|
79
79
|
return jnp.bincount(depths, length=depth)
|
|
80
80
|
|
|
81
|
+
|
|
81
82
|
def trace_depth_distr(split_trees_trace):
|
|
82
83
|
return jax.vmap(forest_depth_distr)(split_trees_trace)
|
|
83
84
|
|
|
85
|
+
|
|
84
86
|
def points_per_leaf_distr(var_tree, split_tree, X):
|
|
85
87
|
traverse_tree = jax.vmap(grove.traverse_tree, in_axes=(1, None, None))
|
|
86
88
|
indices = traverse_tree(X, var_tree, split_tree)
|
|
87
|
-
count_tree = jnp.zeros(
|
|
89
|
+
count_tree = jnp.zeros(
|
|
90
|
+
2 * split_tree.size, dtype=jaxext.minimal_unsigned_dtype(indices.size)
|
|
91
|
+
)
|
|
88
92
|
count_tree = count_tree.at[indices].add(1)
|
|
89
93
|
is_leaf = grove.is_actual_leaf(split_tree, add_bottom_level=True).view(jnp.uint8)
|
|
90
94
|
return jnp.bincount(count_tree, is_leaf, length=X.shape[1] + 1)
|
|
91
95
|
|
|
96
|
+
|
|
92
97
|
def forest_points_per_leaf_distr(bart, X):
|
|
93
98
|
distr = jnp.zeros(X.shape[1] + 1, int)
|
|
94
99
|
trees = bart['var_trees'], bart['split_trees']
|
|
100
|
+
|
|
95
101
|
def loop(distr, tree):
|
|
96
102
|
return distr + points_per_leaf_distr(*tree, X), None
|
|
103
|
+
|
|
97
104
|
distr, _ = lax.scan(loop, distr, trees)
|
|
98
105
|
return distr
|
|
99
106
|
|
|
107
|
+
|
|
100
108
|
def trace_points_per_leaf_distr(bart, X):
|
|
101
109
|
def loop(_, bart):
|
|
102
110
|
return None, forest_points_per_leaf_distr(bart, X)
|
|
111
|
+
|
|
103
112
|
_, distr = lax.scan(loop, None, bart)
|
|
104
113
|
return distr
|
|
105
114
|
|
|
115
|
+
|
|
106
116
|
def check_types(leaf_tree, var_tree, split_tree, max_split):
|
|
107
117
|
expected_var_dtype = jaxext.minimal_unsigned_dtype(max_split.size - 1)
|
|
108
118
|
expected_split_dtype = max_split.dtype
|
|
109
|
-
return
|
|
119
|
+
return (
|
|
120
|
+
var_tree.dtype == expected_var_dtype
|
|
121
|
+
and split_tree.dtype == expected_split_dtype
|
|
122
|
+
)
|
|
123
|
+
|
|
110
124
|
|
|
111
125
|
def check_sizes(leaf_tree, var_tree, split_tree, max_split):
|
|
112
126
|
return leaf_tree.size == 2 * var_tree.size == 2 * split_tree.size
|
|
113
127
|
|
|
128
|
+
|
|
114
129
|
def check_unused_node(leaf_tree, var_tree, split_tree, max_split):
|
|
115
130
|
return (var_tree[0] == 0) & (split_tree[0] == 0)
|
|
116
131
|
|
|
132
|
+
|
|
117
133
|
def check_leaf_values(leaf_tree, var_tree, split_tree, max_split):
|
|
118
134
|
return jnp.all(jnp.isfinite(leaf_tree))
|
|
119
135
|
|
|
136
|
+
|
|
120
137
|
def check_stray_nodes(leaf_tree, var_tree, split_tree, max_split):
|
|
121
|
-
index = jnp.arange(
|
|
138
|
+
index = jnp.arange(
|
|
139
|
+
2 * split_tree.size,
|
|
140
|
+
dtype=jaxext.minimal_unsigned_dtype(2 * split_tree.size - 1),
|
|
141
|
+
)
|
|
122
142
|
parent_index = index >> 1
|
|
123
143
|
is_not_leaf = split_tree.at[index].get(mode='fill', fill_value=0) != 0
|
|
124
144
|
parent_is_leaf = split_tree[parent_index] == 0
|
|
@@ -126,6 +146,7 @@ def check_stray_nodes(leaf_tree, var_tree, split_tree, max_split):
|
|
|
126
146
|
stray = stray.at[1].set(False)
|
|
127
147
|
return ~jnp.any(stray)
|
|
128
148
|
|
|
149
|
+
|
|
129
150
|
check_functions = [
|
|
130
151
|
check_types,
|
|
131
152
|
check_sizes,
|
|
@@ -134,6 +155,7 @@ check_functions = [
|
|
|
134
155
|
check_stray_nodes,
|
|
135
156
|
]
|
|
136
157
|
|
|
158
|
+
|
|
137
159
|
def check_tree(leaf_tree, var_tree, split_tree, max_split):
|
|
138
160
|
error_type = jaxext.minimal_unsigned_dtype(2 ** len(check_functions) - 1)
|
|
139
161
|
error = error_type(0)
|
|
@@ -144,15 +166,19 @@ def check_tree(leaf_tree, var_tree, split_tree, max_split):
|
|
|
144
166
|
error |= bit
|
|
145
167
|
return error
|
|
146
168
|
|
|
169
|
+
|
|
147
170
|
def describe_error(error):
|
|
148
|
-
return [
|
|
149
|
-
|
|
150
|
-
for i, func in enumerate(check_functions)
|
|
151
|
-
if error & (1 << i)
|
|
152
|
-
]
|
|
171
|
+
return [func.__name__ for i, func in enumerate(check_functions) if error & (1 << i)]
|
|
172
|
+
|
|
153
173
|
|
|
154
174
|
check_forest = jax.vmap(check_tree, in_axes=(0, 0, 0, None))
|
|
155
175
|
|
|
176
|
+
|
|
156
177
|
@functools.partial(jax.vmap, in_axes=(0, None))
|
|
157
178
|
def check_trace(trace, state):
|
|
158
|
-
return check_forest(
|
|
179
|
+
return check_forest(
|
|
180
|
+
trace['leaf_trees'],
|
|
181
|
+
trace['var_trees'],
|
|
182
|
+
trace['split_trees'],
|
|
183
|
+
state.max_split,
|
|
184
|
+
)
|
bartz/grove.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
# bartz/src/bartz/grove.py
|
|
2
2
|
#
|
|
3
|
-
# Copyright (c) 2024, Giacomo Petrillo
|
|
3
|
+
# Copyright (c) 2024-2025, Giacomo Petrillo
|
|
4
4
|
#
|
|
5
5
|
# This file is part of bartz.
|
|
6
6
|
#
|
|
@@ -34,7 +34,7 @@ The 'leaf' array contains the values in the leaves.
|
|
|
34
34
|
|
|
35
35
|
The 'var' array contains the axes along which the decision nodes operate.
|
|
36
36
|
|
|
37
|
-
The 'split' array contains the decision boundaries. The boundaries are open on the right, i.e., a point belongs to the left child iff x < split. Whether a node is a leaf is indicated by the corresponding 'split' element being 0.
|
|
37
|
+
The 'split' array contains the decision boundaries. The boundaries are open on the right, i.e., a point belongs to the left child iff x < split. Whether a node is a leaf is indicated by the corresponding 'split' element being 0. Unused nodes also have split set to 0.
|
|
38
38
|
|
|
39
39
|
Since the nodes at the bottom can only be leaves and not decision nodes, the 'var' and 'split' arrays have half the length of the 'leaf' array.
|
|
40
40
|
|
|
@@ -44,11 +44,12 @@ import functools
|
|
|
44
44
|
import math
|
|
45
45
|
|
|
46
46
|
import jax
|
|
47
|
-
from jax import numpy as jnp
|
|
48
47
|
from jax import lax
|
|
48
|
+
from jax import numpy as jnp
|
|
49
49
|
|
|
50
50
|
from . import jaxext
|
|
51
51
|
|
|
52
|
+
|
|
52
53
|
def make_tree(depth, dtype):
|
|
53
54
|
"""
|
|
54
55
|
Make an array to represent a binary tree.
|
|
@@ -66,7 +67,8 @@ def make_tree(depth, dtype):
|
|
|
66
67
|
tree : array
|
|
67
68
|
An array of zeroes with shape (2 ** depth,).
|
|
68
69
|
"""
|
|
69
|
-
return jnp.zeros(2
|
|
70
|
+
return jnp.zeros(2**depth, dtype)
|
|
71
|
+
|
|
70
72
|
|
|
71
73
|
def tree_depth(tree):
|
|
72
74
|
"""
|
|
@@ -85,6 +87,7 @@ def tree_depth(tree):
|
|
|
85
87
|
"""
|
|
86
88
|
return int(round(math.log2(tree.shape[-1])))
|
|
87
89
|
|
|
90
|
+
|
|
88
91
|
def traverse_tree(x, var_tree, split_tree):
|
|
89
92
|
"""
|
|
90
93
|
Find the leaf where a point falls into.
|
|
@@ -103,7 +106,6 @@ def traverse_tree(x, var_tree, split_tree):
|
|
|
103
106
|
index : int
|
|
104
107
|
The index of the leaf.
|
|
105
108
|
"""
|
|
106
|
-
|
|
107
109
|
carry = (
|
|
108
110
|
jnp.zeros((), bool),
|
|
109
111
|
jnp.ones((), jaxext.minimal_unsigned_dtype(2 * var_tree.size - 1)),
|
|
@@ -125,6 +127,7 @@ def traverse_tree(x, var_tree, split_tree):
|
|
|
125
127
|
(_, index), _ = lax.scan(loop, carry, None, depth, unroll=16)
|
|
126
128
|
return index
|
|
127
129
|
|
|
130
|
+
|
|
128
131
|
@functools.partial(jaxext.vmap_nodoc, in_axes=(None, 0, 0))
|
|
129
132
|
@functools.partial(jaxext.vmap_nodoc, in_axes=(1, None, None))
|
|
130
133
|
def traverse_forest(X, var_trees, split_trees):
|
|
@@ -147,6 +150,7 @@ def traverse_forest(X, var_trees, split_trees):
|
|
|
147
150
|
"""
|
|
148
151
|
return traverse_tree(X, var_trees, split_trees)
|
|
149
152
|
|
|
153
|
+
|
|
150
154
|
def evaluate_forest(X, leaf_trees, var_trees, split_trees, dtype=None, sum_trees=True):
|
|
151
155
|
"""
|
|
152
156
|
Evaluate a ensemble of trees at an array of points.
|
|
@@ -178,11 +182,12 @@ def evaluate_forest(X, leaf_trees, var_trees, split_trees, dtype=None, sum_trees
|
|
|
178
182
|
leaves = leaf_trees[tree_index[:, None], indices]
|
|
179
183
|
if sum_trees:
|
|
180
184
|
return jnp.sum(leaves, axis=0, dtype=dtype)
|
|
181
|
-
|
|
182
|
-
|
|
185
|
+
# this sum suggests to swap the vmaps, but I think it's better for X
|
|
186
|
+
# copying to keep it that way
|
|
183
187
|
else:
|
|
184
188
|
return leaves
|
|
185
189
|
|
|
190
|
+
|
|
186
191
|
def is_actual_leaf(split_tree, *, add_bottom_level=False):
|
|
187
192
|
"""
|
|
188
193
|
Return a mask indicating the leaf nodes in a tree.
|
|
@@ -211,6 +216,7 @@ def is_actual_leaf(split_tree, *, add_bottom_level=False):
|
|
|
211
216
|
parent_nonleaf = parent_nonleaf.at[1].set(True)
|
|
212
217
|
return is_leaf & parent_nonleaf
|
|
213
218
|
|
|
219
|
+
|
|
214
220
|
def is_leaves_parent(split_tree):
|
|
215
221
|
"""
|
|
216
222
|
Return a mask indicating the nodes with leaf (and only leaf) children.
|
|
@@ -225,14 +231,17 @@ def is_leaves_parent(split_tree):
|
|
|
225
231
|
is_leaves_parent : bool array (2 ** (d - 1),)
|
|
226
232
|
The mask indicating which nodes have leaf children.
|
|
227
233
|
"""
|
|
228
|
-
index = jnp.arange(
|
|
229
|
-
|
|
230
|
-
|
|
234
|
+
index = jnp.arange(
|
|
235
|
+
split_tree.size, dtype=jaxext.minimal_unsigned_dtype(2 * split_tree.size - 1)
|
|
236
|
+
)
|
|
237
|
+
left_index = index << 1 # left child
|
|
238
|
+
right_index = left_index + 1 # right child
|
|
231
239
|
left_leaf = split_tree.at[left_index].get(mode='fill', fill_value=0) == 0
|
|
232
240
|
right_leaf = split_tree.at[right_index].get(mode='fill', fill_value=0) == 0
|
|
233
241
|
is_not_leaf = split_tree.astype(bool)
|
|
234
242
|
return is_not_leaf & left_leaf & right_leaf
|
|
235
|
-
|
|
243
|
+
# the 0-th item has split == 0, so it's not counted
|
|
244
|
+
|
|
236
245
|
|
|
237
246
|
def tree_depths(tree_length):
|
|
238
247
|
"""
|
|
@@ -253,8 +262,49 @@ def tree_depths(tree_length):
|
|
|
253
262
|
depths = []
|
|
254
263
|
depth = 0
|
|
255
264
|
for i in range(tree_length):
|
|
256
|
-
if i == 2
|
|
265
|
+
if i == 2**depth:
|
|
257
266
|
depth += 1
|
|
258
267
|
depths.append(depth - 1)
|
|
259
268
|
depths[0] = 0
|
|
260
269
|
return jnp.array(depths, jaxext.minimal_unsigned_dtype(max(depths)))
|
|
270
|
+
|
|
271
|
+
|
|
272
|
+
def is_used(split_tree):
|
|
273
|
+
"""
|
|
274
|
+
Return a mask indicating the used nodes in a tree.
|
|
275
|
+
|
|
276
|
+
Parameters
|
|
277
|
+
----------
|
|
278
|
+
split_tree : int array (2 ** (d - 1),)
|
|
279
|
+
The decision boundaries of the tree.
|
|
280
|
+
|
|
281
|
+
Returns
|
|
282
|
+
-------
|
|
283
|
+
is_used : bool array (2 ** d,)
|
|
284
|
+
A mask indicating which nodes are actually used.
|
|
285
|
+
"""
|
|
286
|
+
internal_node = split_tree.astype(bool)
|
|
287
|
+
internal_node = jnp.concatenate([internal_node, jnp.zeros_like(internal_node)])
|
|
288
|
+
actual_leaf = is_actual_leaf(split_tree, add_bottom_level=True)
|
|
289
|
+
return internal_node | actual_leaf
|
|
290
|
+
|
|
291
|
+
|
|
292
|
+
def forest_fill(split_trees):
|
|
293
|
+
"""
|
|
294
|
+
Return the fraction of used nodes in a set of trees.
|
|
295
|
+
|
|
296
|
+
Parameters
|
|
297
|
+
----------
|
|
298
|
+
split_trees : array (m, 2 ** (d - 1),)
|
|
299
|
+
The decision boundaries of the trees.
|
|
300
|
+
|
|
301
|
+
Returns
|
|
302
|
+
-------
|
|
303
|
+
fill : float
|
|
304
|
+
The number of tree nodes in the forest over the maximum number that
|
|
305
|
+
could be stored in the arrays.
|
|
306
|
+
"""
|
|
307
|
+
m, _ = split_trees.shape
|
|
308
|
+
used = jax.vmap(is_used)(split_trees)
|
|
309
|
+
count = jnp.count_nonzero(used)
|
|
310
|
+
return count / (used.size - m)
|
bartz/jaxext.py
CHANGED
|
@@ -22,60 +22,74 @@
|
|
|
22
22
|
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
23
23
|
# SOFTWARE.
|
|
24
24
|
|
|
25
|
+
"""Additions to jax."""
|
|
26
|
+
|
|
25
27
|
import functools
|
|
26
28
|
import math
|
|
27
29
|
import warnings
|
|
28
30
|
|
|
29
|
-
from scipy import special
|
|
30
31
|
import jax
|
|
32
|
+
from jax import lax, random, tree_util
|
|
31
33
|
from jax import numpy as jnp
|
|
32
|
-
from
|
|
33
|
-
|
|
34
|
+
from scipy import special
|
|
35
|
+
|
|
34
36
|
|
|
35
37
|
def float_type(*args):
|
|
36
|
-
"""
|
|
37
|
-
Determine the jax floating point result type given operands/types.
|
|
38
|
-
"""
|
|
38
|
+
"""Determine the jax floating point result type given operands/types."""
|
|
39
39
|
t = jnp.result_type(*args)
|
|
40
40
|
return jnp.sin(jnp.empty(0, t)).dtype
|
|
41
41
|
|
|
42
|
-
|
|
42
|
+
|
|
43
|
+
def _castto(func, type):
|
|
43
44
|
@functools.wraps(func)
|
|
44
45
|
def newfunc(*args, **kw):
|
|
45
46
|
return func(*args, **kw).astype(type)
|
|
47
|
+
|
|
46
48
|
return newfunc
|
|
47
49
|
|
|
50
|
+
|
|
48
51
|
class scipy:
|
|
52
|
+
"""Mockup of the :external:py:mod:`scipy` module."""
|
|
49
53
|
|
|
50
54
|
class special:
|
|
55
|
+
"""Mockup of the :external:py:mod:`scipy.special` module."""
|
|
51
56
|
|
|
52
|
-
@
|
|
57
|
+
@staticmethod
|
|
53
58
|
def gammainccinv(a, y):
|
|
59
|
+
"""Survival function inverse of the Gamma(a, 1) distribution."""
|
|
54
60
|
a = jnp.asarray(a)
|
|
55
61
|
y = jnp.asarray(y)
|
|
56
62
|
shape = jnp.broadcast_shapes(a.shape, y.shape)
|
|
57
63
|
dtype = float_type(a.dtype, y.dtype)
|
|
58
64
|
dummy = jax.ShapeDtypeStruct(shape, dtype)
|
|
59
|
-
ufunc =
|
|
65
|
+
ufunc = _castto(special.gammainccinv, dtype)
|
|
60
66
|
return jax.pure_callback(ufunc, dummy, a, y, vmap_method='expand_dims')
|
|
61
67
|
|
|
62
68
|
class stats:
|
|
69
|
+
"""Mockup of the :external:py:mod:`scipy.stats` module."""
|
|
63
70
|
|
|
64
71
|
class invgamma:
|
|
72
|
+
"""Class that represents the distribution InvGamma(a, 1)."""
|
|
65
73
|
|
|
74
|
+
@staticmethod
|
|
66
75
|
def ppf(q, a):
|
|
76
|
+
"""Percentile point function."""
|
|
67
77
|
return 1 / scipy.special.gammainccinv(a, q)
|
|
68
78
|
|
|
69
|
-
|
|
79
|
+
|
|
70
80
|
def vmap_nodoc(fun, *args, **kw):
|
|
71
81
|
"""
|
|
72
|
-
|
|
82
|
+
Acts like `jax.vmap` but preserves the docstring of the function unchanged.
|
|
83
|
+
|
|
84
|
+
This is useful if the docstring already takes into account that the
|
|
85
|
+
arguments have additional axes due to vmap.
|
|
73
86
|
"""
|
|
74
87
|
doc = fun.__doc__
|
|
75
88
|
fun = jax.vmap(fun, *args, **kw)
|
|
76
89
|
fun.__doc__ = doc
|
|
77
90
|
return fun
|
|
78
91
|
|
|
92
|
+
|
|
79
93
|
def huge_value(x):
|
|
80
94
|
"""
|
|
81
95
|
Return the maximum value that can be stored in `x`.
|
|
@@ -95,23 +109,23 @@ def huge_value(x):
|
|
|
95
109
|
else:
|
|
96
110
|
return jnp.inf
|
|
97
111
|
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
Return the smallest unsigned integer dtype that can represent
|
|
101
|
-
|
|
102
|
-
"""
|
|
103
|
-
if max_value < 2 ** 8:
|
|
112
|
+
|
|
113
|
+
def minimal_unsigned_dtype(value):
|
|
114
|
+
"""Return the smallest unsigned integer dtype that can represent `value`."""
|
|
115
|
+
if value < 2**8:
|
|
104
116
|
return jnp.uint8
|
|
105
|
-
if
|
|
117
|
+
if value < 2**16:
|
|
106
118
|
return jnp.uint16
|
|
107
|
-
if
|
|
119
|
+
if value < 2**32:
|
|
108
120
|
return jnp.uint32
|
|
109
121
|
return jnp.uint64
|
|
110
122
|
|
|
123
|
+
|
|
111
124
|
def signed_to_unsigned(int_dtype):
|
|
112
125
|
"""
|
|
113
|
-
Map a signed integer type to its unsigned counterpart.
|
|
114
|
-
|
|
126
|
+
Map a signed integer type to its unsigned counterpart.
|
|
127
|
+
|
|
128
|
+
Unsigned types are passed through.
|
|
115
129
|
"""
|
|
116
130
|
assert jnp.issubdtype(int_dtype, jnp.integer)
|
|
117
131
|
if jnp.issubdtype(int_dtype, jnp.unsignedinteger):
|
|
@@ -125,12 +139,12 @@ def signed_to_unsigned(int_dtype):
|
|
|
125
139
|
if int_dtype == jnp.int64:
|
|
126
140
|
return jnp.uint64
|
|
127
141
|
|
|
142
|
+
|
|
128
143
|
def ensure_unsigned(x):
|
|
129
|
-
"""
|
|
130
|
-
If x has signed integer type, cast it to the unsigned dtype of the same size.
|
|
131
|
-
"""
|
|
144
|
+
"""If x has signed integer type, cast it to the unsigned dtype of the same size."""
|
|
132
145
|
return x.astype(signed_to_unsigned(x.dtype))
|
|
133
146
|
|
|
147
|
+
|
|
134
148
|
@functools.partial(jax.jit, static_argnums=(1,))
|
|
135
149
|
def unique(x, size, fill_value):
|
|
136
150
|
"""
|
|
@@ -158,15 +172,18 @@ def unique(x, size, fill_value):
|
|
|
158
172
|
if size == 0:
|
|
159
173
|
return jnp.empty(0, x.dtype), 0
|
|
160
174
|
x = jnp.sort(x)
|
|
175
|
+
|
|
161
176
|
def loop(carry, x):
|
|
162
177
|
i_out, i_in, last, out = carry
|
|
163
178
|
i_out = jnp.where(x == last, i_out, i_out + 1)
|
|
164
179
|
out = out.at[i_out].set(x)
|
|
165
180
|
return (i_out, i_in + 1, x, out), None
|
|
181
|
+
|
|
166
182
|
carry = 0, 0, x[0], jnp.full(size, fill_value, x.dtype)
|
|
167
183
|
(actual_length, _, _, out), _ = jax.lax.scan(loop, carry, x[:size])
|
|
168
184
|
return out, actual_length + 1
|
|
169
185
|
|
|
186
|
+
|
|
170
187
|
def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False):
|
|
171
188
|
"""
|
|
172
189
|
Batch a function such that each batch is smaller than a threshold.
|
|
@@ -203,6 +220,7 @@ def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
|
|
|
203
220
|
def check_no_nones(axes, tree):
|
|
204
221
|
def check_not_none(_, axis):
|
|
205
222
|
assert axis is not None
|
|
223
|
+
|
|
206
224
|
tree_util.tree_map(check_not_none, tree, axes)
|
|
207
225
|
|
|
208
226
|
def extract_size(axes, tree):
|
|
@@ -211,6 +229,7 @@ def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
|
|
|
211
229
|
return None
|
|
212
230
|
else:
|
|
213
231
|
return x.shape[axis]
|
|
232
|
+
|
|
214
233
|
sizes = tree_util.tree_map(get_size, tree, axes)
|
|
215
234
|
sizes, _ = tree_util.tree_flatten(sizes)
|
|
216
235
|
assert all(s == sizes[0] for s in sizes)
|
|
@@ -219,6 +238,7 @@ def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
|
|
|
219
238
|
def sum_nbytes(tree):
|
|
220
239
|
def nbytes(x):
|
|
221
240
|
return math.prod(x.shape) * x.dtype.itemsize
|
|
241
|
+
|
|
222
242
|
return tree_util.tree_reduce(lambda size, x: size + nbytes(x), tree, 0)
|
|
223
243
|
|
|
224
244
|
def next_divisor_small(dividend, min_divisor):
|
|
@@ -247,6 +267,7 @@ def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
|
|
|
247
267
|
return None
|
|
248
268
|
else:
|
|
249
269
|
return x
|
|
270
|
+
|
|
250
271
|
return tree_util.tree_map(pull_nonbatched, tree, axes), tree
|
|
251
272
|
|
|
252
273
|
def push_nonbatched(axes, tree, original_tree):
|
|
@@ -255,32 +276,38 @@ def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
|
|
|
255
276
|
return original_x
|
|
256
277
|
else:
|
|
257
278
|
return x
|
|
279
|
+
|
|
258
280
|
return tree_util.tree_map(push_nonbatched, original_tree, tree, axes)
|
|
259
281
|
|
|
260
282
|
def move_axes_out(axes, tree):
|
|
261
283
|
def move_axis_out(x, axis):
|
|
262
284
|
return jnp.moveaxis(x, axis, 0)
|
|
285
|
+
|
|
263
286
|
return tree_util.tree_map(move_axis_out, tree, axes)
|
|
264
287
|
|
|
265
288
|
def move_axes_in(axes, tree):
|
|
266
289
|
def move_axis_in(x, axis):
|
|
267
290
|
return jnp.moveaxis(x, 0, axis)
|
|
291
|
+
|
|
268
292
|
return tree_util.tree_map(move_axis_in, tree, axes)
|
|
269
293
|
|
|
270
294
|
def batch(tree, nbatches):
|
|
271
295
|
def batch(x):
|
|
272
296
|
return x.reshape((nbatches, x.shape[0] // nbatches) + x.shape[1:])
|
|
297
|
+
|
|
273
298
|
return tree_util.tree_map(batch, tree)
|
|
274
299
|
|
|
275
300
|
def unbatch(tree):
|
|
276
301
|
def unbatch(x):
|
|
277
302
|
return x.reshape((x.shape[0] * x.shape[1],) + x.shape[2:])
|
|
303
|
+
|
|
278
304
|
return tree_util.tree_map(unbatch, tree)
|
|
279
305
|
|
|
280
306
|
def check_same(tree1, tree2):
|
|
281
307
|
def check_same(x1, x2):
|
|
282
308
|
assert x1.shape == x2.shape
|
|
283
309
|
assert x1.dtype == x2.dtype
|
|
310
|
+
|
|
284
311
|
tree_util.tree_map(check_same, tree1, tree2)
|
|
285
312
|
|
|
286
313
|
initial_in_axes = in_axes
|
|
@@ -300,7 +327,9 @@ def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
|
|
|
300
327
|
args, nonbatched_args = pull_nonbatched(in_axes, args)
|
|
301
328
|
|
|
302
329
|
total_nbytes = sum_nbytes((args, example_result))
|
|
303
|
-
min_nbatches = total_nbytes // max_io_nbytes + bool(
|
|
330
|
+
min_nbatches = total_nbytes // max_io_nbytes + bool(
|
|
331
|
+
total_nbytes % max_io_nbytes
|
|
332
|
+
)
|
|
304
333
|
min_nbatches = max(1, min_nbatches)
|
|
305
334
|
nbatches = next_divisor(size, min_nbatches)
|
|
306
335
|
assert 1 <= nbatches <= max(1, size)
|
|
@@ -310,7 +339,9 @@ def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
|
|
|
310
339
|
batch_nbytes = total_nbytes // nbatches
|
|
311
340
|
if batch_nbytes > max_io_nbytes:
|
|
312
341
|
assert size == nbatches
|
|
313
|
-
warnings.warn(
|
|
342
|
+
warnings.warn(
|
|
343
|
+
f'batch_nbytes = {batch_nbytes} > max_io_nbytes = {max_io_nbytes}'
|
|
344
|
+
)
|
|
314
345
|
|
|
315
346
|
def loop(_, args):
|
|
316
347
|
args = move_axes_in(in_axes, args)
|
|
@@ -333,17 +364,60 @@ def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
|
|
|
333
364
|
|
|
334
365
|
return batched_func
|
|
335
366
|
|
|
336
|
-
@tree_util.register_pytree_node_class
|
|
337
|
-
class LeafDict(dict):
|
|
338
|
-
""" dictionary that acts as a leaf in jax pytrees, to store compile-time
|
|
339
|
-
values """
|
|
340
367
|
|
|
341
|
-
|
|
342
|
-
|
|
368
|
+
class split:
|
|
369
|
+
"""
|
|
370
|
+
Split a key into `num` keys.
|
|
343
371
|
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
372
|
+
Parameters
|
|
373
|
+
----------
|
|
374
|
+
key : jax.dtypes.prng_key array
|
|
375
|
+
The key to split.
|
|
376
|
+
num : int
|
|
377
|
+
The number of keys to split into.
|
|
378
|
+
"""
|
|
347
379
|
|
|
348
|
-
def
|
|
349
|
-
|
|
380
|
+
def __init__(self, key, num=2):
|
|
381
|
+
self._keys = random.split(key, num)
|
|
382
|
+
|
|
383
|
+
def __len__(self):
|
|
384
|
+
return self._keys.size
|
|
385
|
+
|
|
386
|
+
def pop(self, shape=None):
|
|
387
|
+
"""
|
|
388
|
+
Pop one or more keys from the list.
|
|
389
|
+
|
|
390
|
+
Parameters
|
|
391
|
+
----------
|
|
392
|
+
shape : int or tuple of int, optional
|
|
393
|
+
The shape of the keys to pop. If `None`, a single key is popped.
|
|
394
|
+
If an integer, that many keys are popped. If a tuple, the keys are
|
|
395
|
+
reshaped to that shape.
|
|
396
|
+
|
|
397
|
+
Returns
|
|
398
|
+
-------
|
|
399
|
+
keys : jax.dtypes.prng_key array
|
|
400
|
+
The popped keys.
|
|
401
|
+
|
|
402
|
+
Raises
|
|
403
|
+
------
|
|
404
|
+
IndexError
|
|
405
|
+
If `shape` is larger than the number of keys left in the list.
|
|
406
|
+
|
|
407
|
+
Notes
|
|
408
|
+
-----
|
|
409
|
+
The keys are popped from the beginning of the list, so for example
|
|
410
|
+
``list(keys.pop(2))`` is equivalent to ``[keys.pop(), keys.pop()]``.
|
|
411
|
+
"""
|
|
412
|
+
if shape is None:
|
|
413
|
+
shape = ()
|
|
414
|
+
elif not isinstance(shape, tuple):
|
|
415
|
+
shape = (shape,)
|
|
416
|
+
size_to_pop = math.prod(shape)
|
|
417
|
+
if size_to_pop > self._keys.size:
|
|
418
|
+
raise IndexError(
|
|
419
|
+
f'Cannot pop {size_to_pop} keys from {self._keys.size} keys'
|
|
420
|
+
)
|
|
421
|
+
popped_keys = self._keys[:size_to_pop]
|
|
422
|
+
self._keys = self._keys[size_to_pop:]
|
|
423
|
+
return popped_keys.reshape(shape)
|