bartz 0.1.0__tar.gz → 0.2.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {bartz-0.1.0 → bartz-0.2.1}/PKG-INFO +7 -1
- bartz-0.2.1/README.md +9 -0
- {bartz-0.1.0 → bartz-0.2.1}/pyproject.toml +2 -1
- bartz-0.1.0/src/bartz/interface.py → bartz-0.2.1/src/bartz/BART.py +4 -3
- {bartz-0.1.0 → bartz-0.2.1}/src/bartz/__init__.py +6 -1
- bartz-0.2.1/src/bartz/_version.py +1 -0
- {bartz-0.1.0 → bartz-0.2.1}/src/bartz/debug.py +5 -4
- {bartz-0.1.0 → bartz-0.2.1}/src/bartz/grove.py +36 -36
- bartz-0.2.1/src/bartz/jaxext.py +341 -0
- {bartz-0.1.0 → bartz-0.2.1}/src/bartz/mcmcloop.py +19 -11
- {bartz-0.1.0 → bartz-0.2.1}/src/bartz/mcmcstep.py +200 -73
- {bartz-0.1.0 → bartz-0.2.1}/src/bartz/prepcovars.py +25 -30
- bartz-0.1.0/README.md +0 -3
- bartz-0.1.0/src/bartz/_version.py +0 -1
- bartz-0.1.0/src/bartz/jaxext.py +0 -85
- {bartz-0.1.0 → bartz-0.2.1}/LICENSE +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: bartz
|
|
3
|
-
Version: 0.1
|
|
3
|
+
Version: 0.2.1
|
|
4
4
|
Summary: A JAX implementation of BART
|
|
5
5
|
Home-page: https://github.com/Gattocrucco/bartz
|
|
6
6
|
License: MIT
|
|
@@ -20,7 +20,13 @@ Project-URL: Bug Tracker, https://github.com/Gattocrucco/bartz/issues
|
|
|
20
20
|
Project-URL: Repository, https://github.com/Gattocrucco/bartz
|
|
21
21
|
Description-Content-Type: text/markdown
|
|
22
22
|
|
|
23
|
+
[](https://pypi.org/project/bartz/)
|
|
24
|
+
|
|
23
25
|
# BART vectoriZed
|
|
24
26
|
|
|
25
27
|
A branchless vectorized implementation of Bayesian Additive Regression Trees (BART) in JAX.
|
|
26
28
|
|
|
29
|
+
BART is a nonparametric Bayesian regression technique. Given predictors $X$ and responses $y$, BART finds a function to predict $y$ given $X$. The result of the inference is a sample of possible functions, representing the uncertainty over the determination of the function.
|
|
30
|
+
|
|
31
|
+
This Python module provides an implementation of BART that runs on GPU, to process large datasets faster. It is also good on CPU. Most other implementations of BART are for R, and run on CPU only.
|
|
32
|
+
|
bartz-0.2.1/README.md
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
[](https://pypi.org/project/bartz/)
|
|
2
|
+
|
|
3
|
+
# BART vectoriZed
|
|
4
|
+
|
|
5
|
+
A branchless vectorized implementation of Bayesian Additive Regression Trees (BART) in JAX.
|
|
6
|
+
|
|
7
|
+
BART is a nonparametric Bayesian regression technique. Given predictors $X$ and responses $y$, BART finds a function to predict $y$ given $X$. The result of the inference is a sample of possible functions, representing the uncertainty over the determination of the function.
|
|
8
|
+
|
|
9
|
+
This Python module provides an implementation of BART that runs on GPU, to process large datasets faster. It is also good on CPU. Most other implementations of BART are for R, and run on CPU only.
|
|
@@ -28,7 +28,7 @@ build-backend = "poetry.core.masonry.api"
|
|
|
28
28
|
|
|
29
29
|
[tool.poetry]
|
|
30
30
|
name = "bartz"
|
|
31
|
-
version = "0.1
|
|
31
|
+
version = "0.2.1"
|
|
32
32
|
description = "A JAX implementation of BART"
|
|
33
33
|
authors = ["Giacomo Petrillo <info@giacomopetrillo.com>"]
|
|
34
34
|
license = "MIT"
|
|
@@ -53,6 +53,7 @@ ipython = "^8.22.2"
|
|
|
53
53
|
matplotlib = "^3.8.3"
|
|
54
54
|
appnope = "^0.1.4"
|
|
55
55
|
tomli = "^2.0.1"
|
|
56
|
+
packaging = "^24.0"
|
|
56
57
|
|
|
57
58
|
[tool.poetry.group.test.dependencies]
|
|
58
59
|
coverage = "^7.4.3"
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# bartz/src/bartz/
|
|
1
|
+
# bartz/src/bartz/BART.py
|
|
2
2
|
#
|
|
3
3
|
# Copyright (c) 2024, Giacomo Petrillo
|
|
4
4
|
#
|
|
@@ -33,7 +33,7 @@ from . import mcmcstep
|
|
|
33
33
|
from . import mcmcloop
|
|
34
34
|
from . import prepcovars
|
|
35
35
|
|
|
36
|
-
class
|
|
36
|
+
class gbart:
|
|
37
37
|
"""
|
|
38
38
|
Nonparametric regression with Bayesian Additive Regression Trees (BART).
|
|
39
39
|
|
|
@@ -133,7 +133,7 @@ class BART:
|
|
|
133
133
|
|
|
134
134
|
Notes
|
|
135
135
|
-----
|
|
136
|
-
This interface imitates the function `
|
|
136
|
+
This interface imitates the function `gbart` from the R package `BART
|
|
137
137
|
<https://cran.r-project.org/package=BART>`_, but with these differences:
|
|
138
138
|
|
|
139
139
|
- If `x_train` and `x_test` are matrices, they have one predictor per row
|
|
@@ -142,6 +142,7 @@ class BART:
|
|
|
142
142
|
- `usequants` is always `True`.
|
|
143
143
|
- `rm_const` is always `False`.
|
|
144
144
|
- The default `numcut` is 255 instead of 100.
|
|
145
|
+
- A lot of functionality is missing (variable selection, discrete response).
|
|
145
146
|
- There are some additional attributes, and some missing.
|
|
146
147
|
"""
|
|
147
148
|
|
|
@@ -30,6 +30,11 @@ See the manual at https://gattocrucco.github.io/bartz/docs
|
|
|
30
30
|
|
|
31
31
|
from ._version import __version__
|
|
32
32
|
|
|
33
|
-
from .
|
|
33
|
+
from . import BART
|
|
34
34
|
|
|
35
35
|
from . import debug
|
|
36
|
+
from . import grove
|
|
37
|
+
from . import mcmcstep
|
|
38
|
+
from . import mcmcloop
|
|
39
|
+
from . import prepcovars
|
|
40
|
+
from . import jaxext
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
__version__ = '0.2.1'
|
|
@@ -6,6 +6,7 @@ from jax import lax
|
|
|
6
6
|
|
|
7
7
|
from . import grove
|
|
8
8
|
from . import mcmcstep
|
|
9
|
+
from . import jaxext
|
|
9
10
|
|
|
10
11
|
def print_tree(leaf_tree, var_tree, split_tree, print_all=False):
|
|
11
12
|
|
|
@@ -83,7 +84,7 @@ def trace_depth_distr(split_trees_trace):
|
|
|
83
84
|
def points_per_leaf_distr(var_tree, split_tree, X):
|
|
84
85
|
traverse_tree = jax.vmap(grove.traverse_tree, in_axes=(1, None, None))
|
|
85
86
|
indices = traverse_tree(X, var_tree, split_tree)
|
|
86
|
-
count_tree = jnp.zeros(2 * split_tree.size, dtype=
|
|
87
|
+
count_tree = jnp.zeros(2 * split_tree.size, dtype=jaxext.minimal_unsigned_dtype(indices.size))
|
|
87
88
|
count_tree = count_tree.at[indices].add(1)
|
|
88
89
|
is_leaf = grove.is_actual_leaf(split_tree, add_bottom_level=True).view(jnp.uint8)
|
|
89
90
|
return jnp.bincount(count_tree, is_leaf, length=X.shape[1] + 1)
|
|
@@ -103,7 +104,7 @@ def trace_points_per_leaf_distr(bart, X):
|
|
|
103
104
|
return distr
|
|
104
105
|
|
|
105
106
|
def check_types(leaf_tree, var_tree, split_tree, max_split):
|
|
106
|
-
expected_var_dtype =
|
|
107
|
+
expected_var_dtype = jaxext.minimal_unsigned_dtype(max_split.size - 1)
|
|
107
108
|
expected_split_dtype = max_split.dtype
|
|
108
109
|
return var_tree.dtype == expected_var_dtype and split_tree.dtype == expected_split_dtype
|
|
109
110
|
|
|
@@ -117,7 +118,7 @@ def check_leaf_values(leaf_tree, var_tree, split_tree, max_split):
|
|
|
117
118
|
return jnp.all(jnp.isfinite(leaf_tree))
|
|
118
119
|
|
|
119
120
|
def check_stray_nodes(leaf_tree, var_tree, split_tree, max_split):
|
|
120
|
-
index = jnp.arange(2 * split_tree.size, dtype=
|
|
121
|
+
index = jnp.arange(2 * split_tree.size, dtype=jaxext.minimal_unsigned_dtype(2 * split_tree.size - 1))
|
|
121
122
|
parent_index = index >> 1
|
|
122
123
|
is_not_leaf = split_tree.at[index].get(mode='fill', fill_value=0) != 0
|
|
123
124
|
parent_is_leaf = split_tree[parent_index] == 0
|
|
@@ -134,7 +135,7 @@ check_functions = [
|
|
|
134
135
|
]
|
|
135
136
|
|
|
136
137
|
def check_tree(leaf_tree, var_tree, split_tree, max_split):
|
|
137
|
-
error_type =
|
|
138
|
+
error_type = jaxext.minimal_unsigned_dtype(2 ** len(check_functions) - 1)
|
|
138
139
|
error = error_type(0)
|
|
139
140
|
for i, func in enumerate(check_functions):
|
|
140
141
|
ok = func(leaf_tree, var_tree, split_tree, max_split)
|
|
@@ -44,7 +44,6 @@ import functools
|
|
|
44
44
|
import math
|
|
45
45
|
|
|
46
46
|
import jax
|
|
47
|
-
|
|
48
47
|
from jax import numpy as jnp
|
|
49
48
|
from jax import lax
|
|
50
49
|
|
|
@@ -107,29 +106,47 @@ def traverse_tree(x, var_tree, split_tree):
|
|
|
107
106
|
|
|
108
107
|
carry = (
|
|
109
108
|
jnp.zeros((), bool),
|
|
110
|
-
jnp.ones((), minimal_unsigned_dtype(2 * var_tree.size - 1)),
|
|
109
|
+
jnp.ones((), jaxext.minimal_unsigned_dtype(2 * var_tree.size - 1)),
|
|
111
110
|
)
|
|
112
111
|
|
|
113
112
|
def loop(carry, _):
|
|
114
113
|
leaf_found, index = carry
|
|
115
114
|
|
|
116
|
-
split = split_tree
|
|
117
|
-
var = var_tree
|
|
115
|
+
split = split_tree[index]
|
|
116
|
+
var = var_tree[index]
|
|
118
117
|
|
|
119
|
-
leaf_found |=
|
|
118
|
+
leaf_found |= split == 0
|
|
120
119
|
child_index = (index << 1) + (x[var] >= split)
|
|
121
120
|
index = jnp.where(leaf_found, index, child_index)
|
|
122
121
|
|
|
123
122
|
return (leaf_found, index), None
|
|
124
123
|
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
# - separate and special-case the last iteration
|
|
128
|
-
|
|
129
|
-
depth = 1 + tree_depth(var_tree)
|
|
130
|
-
(_, index), _ = lax.scan(loop, carry, None, depth)
|
|
124
|
+
depth = tree_depth(var_tree)
|
|
125
|
+
(_, index), _ = lax.scan(loop, carry, None, depth, unroll=16)
|
|
131
126
|
return index
|
|
132
127
|
|
|
128
|
+
@functools.partial(jaxext.vmap_nodoc, in_axes=(None, 0, 0))  # map over trees
@functools.partial(jaxext.vmap_nodoc, in_axes=(1, None, None))  # map over points
def traverse_forest(X, var_trees, split_trees):
    """
    Locate, for every tree of a forest and every point, the leaf the point
    falls into.

    Parameters
    ----------
    X : array (p, n)
        The coordinates to evaluate the trees at, one point per column.
    var_trees : array (m, 2 ** (d - 1))
        The decision axes of the trees.
    split_trees : array (m, 2 ** (d - 1))
        The decision boundaries of the trees.

    Returns
    -------
    indices : array (m, n)
        The indices of the leaves.
    """
    # After both vmaps are applied, the arguments seen here are a single
    # point (one column of X) and a single tree.
    leaf = traverse_tree(X, var_trees, split_trees)
    return leaf
|
|
149
|
+
|
|
133
150
|
def evaluate_forest(X, leaf_trees, var_trees, split_trees, dtype):
|
|
134
151
|
"""
|
|
135
152
|
Evaluate a ensemble of trees at an array of points.
|
|
@@ -138,7 +155,7 @@ def evaluate_forest(X, leaf_trees, var_trees, split_trees, dtype):
|
|
|
138
155
|
----------
|
|
139
156
|
X : array (p, n)
|
|
140
157
|
The coordinates to evaluate the trees at.
|
|
141
|
-
leaf_trees : (m, 2 ** d)
|
|
158
|
+
leaf_trees : array (m, 2 ** d)
|
|
142
159
|
The leaf values of the tree or forest. If the input is a forest, the
|
|
143
160
|
first axis is the tree index, and the values are summed.
|
|
144
161
|
var_trees : array (m, 2 ** (d - 1))
|
|
@@ -153,30 +170,13 @@ def evaluate_forest(X, leaf_trees, var_trees, split_trees, dtype):
|
|
|
153
170
|
out : array (n,)
|
|
154
171
|
The sum of the values of the trees at the points in `X`.
|
|
155
172
|
"""
|
|
156
|
-
indices =
|
|
173
|
+
indices = traverse_forest(X, var_trees, split_trees)
|
|
157
174
|
ntree, _ = leaf_trees.shape
|
|
158
|
-
tree_index = jnp.arange(ntree, dtype=minimal_unsigned_dtype(ntree - 1))[:, None]
|
|
175
|
+
tree_index = jnp.arange(ntree, dtype=jaxext.minimal_unsigned_dtype(ntree - 1))[:, None]
|
|
159
176
|
leaves = leaf_trees[tree_index, indices]
|
|
160
177
|
return jnp.sum(leaves, axis=0, dtype=dtype)
|
|
161
|
-
# this sum suggests to swap the vmaps, but I think it's better for X
|
|
162
|
-
|
|
163
|
-
@functools.partial(jax.vmap, in_axes=(None, 0, 0))
|
|
164
|
-
@functools.partial(jax.vmap, in_axes=(1, None, None))
|
|
165
|
-
def _traverse_forest(X, var_trees, split_trees):
|
|
166
|
-
return traverse_tree(X, var_trees, split_trees)
|
|
167
|
-
|
|
168
|
-
def minimal_unsigned_dtype(max_value):
|
|
169
|
-
"""
|
|
170
|
-
Return the smallest unsigned integer dtype that can represent a given
|
|
171
|
-
maximum value.
|
|
172
|
-
"""
|
|
173
|
-
if max_value < 2 ** 8:
|
|
174
|
-
return jnp.uint8
|
|
175
|
-
if max_value < 2 ** 16:
|
|
176
|
-
return jnp.uint16
|
|
177
|
-
if max_value < 2 ** 32:
|
|
178
|
-
return jnp.uint32
|
|
179
|
-
return jnp.uint64
|
|
178
|
+
# this sum suggests swapping the vmaps, but I think the current order is better for X
|
|
179
|
+
# copying to keep it that way
|
|
180
180
|
|
|
181
181
|
def is_actual_leaf(split_tree, *, add_bottom_level=False):
|
|
182
182
|
"""
|
|
@@ -200,7 +200,7 @@ def is_actual_leaf(split_tree, *, add_bottom_level=False):
|
|
|
200
200
|
if add_bottom_level:
|
|
201
201
|
size *= 2
|
|
202
202
|
is_leaf = jnp.concatenate([is_leaf, jnp.ones_like(is_leaf)])
|
|
203
|
-
index = jnp.arange(size, dtype=minimal_unsigned_dtype(size - 1))
|
|
203
|
+
index = jnp.arange(size, dtype=jaxext.minimal_unsigned_dtype(size - 1))
|
|
204
204
|
parent_index = index >> 1
|
|
205
205
|
parent_nonleaf = split_tree[parent_index].astype(bool)
|
|
206
206
|
parent_nonleaf = parent_nonleaf.at[1].set(True)
|
|
@@ -220,7 +220,7 @@ def is_leaves_parent(split_tree):
|
|
|
220
220
|
is_leaves_parent : bool array (2 ** (d - 1),)
|
|
221
221
|
The mask indicating which nodes have leaf children.
|
|
222
222
|
"""
|
|
223
|
-
index = jnp.arange(split_tree.size, dtype=minimal_unsigned_dtype(2 * split_tree.size - 1))
|
|
223
|
+
index = jnp.arange(split_tree.size, dtype=jaxext.minimal_unsigned_dtype(2 * split_tree.size - 1))
|
|
224
224
|
left_index = index << 1 # left child
|
|
225
225
|
right_index = left_index + 1 # right child
|
|
226
226
|
left_leaf = split_tree.at[left_index].get(mode='fill', fill_value=0) == 0
|
|
@@ -252,4 +252,4 @@ def tree_depths(tree_length):
|
|
|
252
252
|
depth += 1
|
|
253
253
|
depths.append(depth - 1)
|
|
254
254
|
depths[0] = 0
|
|
255
|
-
return jnp.array(depths, minimal_unsigned_dtype(max(depths)))
|
|
255
|
+
return jnp.array(depths, jaxext.minimal_unsigned_dtype(max(depths)))
|
|
@@ -0,0 +1,341 @@
|
|
|
1
|
+
# bartz/src/bartz/jaxext.py
|
|
2
|
+
#
|
|
3
|
+
# Copyright (c) 2024, Giacomo Petrillo
|
|
4
|
+
#
|
|
5
|
+
# This file is part of bartz.
|
|
6
|
+
#
|
|
7
|
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
8
|
+
# of this software and associated documentation files (the "Software"), to deal
|
|
9
|
+
# in the Software without restriction, including without limitation the rights
|
|
10
|
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
11
|
+
# copies of the Software, and to permit persons to whom the Software is
|
|
12
|
+
# furnished to do so, subject to the following conditions:
|
|
13
|
+
#
|
|
14
|
+
# The above copyright notice and this permission notice shall be included in all
|
|
15
|
+
# copies or substantial portions of the Software.
|
|
16
|
+
#
|
|
17
|
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
18
|
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
19
|
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
20
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
21
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
22
|
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
23
|
+
# SOFTWARE.
|
|
24
|
+
|
|
25
|
+
import functools
|
|
26
|
+
import math
|
|
27
|
+
import warnings
|
|
28
|
+
|
|
29
|
+
from scipy import special
|
|
30
|
+
import jax
|
|
31
|
+
from jax import numpy as jnp
|
|
32
|
+
from jax import tree_util
|
|
33
|
+
from jax import lax
|
|
34
|
+
|
|
35
|
+
def float_type(*args):
    """
    Determine the jax floating point result type given operands/types.

    The arguments are combined with `jnp.result_type`. The answer is the dtype
    produced by a floating point operation (`jnp.sin`) applied to an empty
    array of that combined type.
    """
    combined = jnp.result_type(*args)
    probe = jnp.empty(0, combined)
    return jnp.sin(probe).dtype
|
|
41
|
+
|
|
42
|
+
def castto(func, type):
    """
    Wrap `func` so that its return value is converted with ``.astype(type)``.

    All positional and keyword arguments are forwarded unchanged. The metadata
    of `func` (name, docstring, ...) is copied onto the wrapper with
    `functools.wraps`.
    """
    # The parameter name shadows the builtin `type`; alias it locally.
    target_dtype = type

    def cast_result(*args, **kw):
        raw = func(*args, **kw)
        return raw.astype(target_dtype)

    return functools.wraps(func)(cast_result)
|
|
47
|
+
|
|
48
|
+
def pure_callback_ufunc(callback, dtype, *args, excluded=None, **kwargs):
    """
    Version of `jax.pure_callback` that deals correctly with ufuncs, see
    `<https://github.com/google/jax/issues/17187>`_.

    The non-excluded arguments are left-padded with size-1 axes to a common
    number of dimensions before being handed to `jax.pure_callback`. A
    numpy-ufunc callback then broadcasts them correctly, also when the call
    is vectorized by `jax.vmap`.

    Parameters
    ----------
    callback : callable
        Host function. It must return an array whose shape is the broadcast
        shape of the non-excluded arguments and whose dtype is `dtype`.
    dtype : dtype
        The dtype of the result.
    *args : array-like
        The arguments to `callback`. Python scalars are accepted.
    excluded : container of int, optional
        Positional indices of `args` that are passed through without padding
        and do not participate in determining the output shape.
    **kwargs
        Forwarded to `jax.pure_callback`. If neither `vmap_method` nor
        `vectorized` is given, a vmap behavior suitable for ufuncs is chosen
        automatically.

    Returns
    -------
    out : jax array
        The result of the callback, with the broadcast shape of the
        non-excluded arguments.
    """
    import inspect

    if excluded is None:
        excluded = ()

    # jnp.shape / jnp.ndim (rather than .shape / .ndim) also accept Python
    # and numpy scalars.
    shape = jnp.broadcast_shapes(*(
        jnp.shape(a)
        for i, a in enumerate(args)
        if i not in excluded
    ))
    ndim = len(shape)
    padded_args = [
        a if i in excluded
        else jnp.expand_dims(a, tuple(range(ndim - jnp.ndim(a))))
        for i, a in enumerate(args)
    ]
    result = jax.ShapeDtypeStruct(shape, dtype)

    if 'vmap_method' not in kwargs and 'vectorized' not in kwargs:
        if 'vmap_method' in inspect.signature(jax.pure_callback).parameters:
            # Newer jax deprecates `vectorized` in favour of `vmap_method`.
            # 'expand_dims' gives unbatched inputs a leading size-1 axis.
            # Together with the rank padding above, this is exactly numpy
            # broadcasting, which is what a ufunc callback expects.
            kwargs['vmap_method'] = 'expand_dims'
        else:
            kwargs['vectorized'] = True
    return jax.pure_callback(callback, result, *padded_args, **kwargs)
|
|
66
|
+
|
|
67
|
+
# TODO: once jax fixes this issue, check the jax version and use the original jax.pure_callback directly on new enough versions
|
|
68
|
+
|
|
69
|
+
class scipy:
    """
    Namespace mirroring a small subset of `scipy`. Its functions accept jax
    arrays and evaluate the scipy implementation through a host callback.
    """

    class special:
        # Within this class body and the method below, `special` still
        # resolves to the module-level `scipy.special` module, because class
        # scopes are not visible to nested scopes.

        @functools.wraps(special.gammainccinv)
        def gammainccinv(a, y):
            a_arr = jnp.asarray(a)
            y_arr = jnp.asarray(y)
            out_dtype = float_type(a_arr.dtype, y_arr.dtype)
            host_func = castto(special.gammainccinv, out_dtype)
            return pure_callback_ufunc(host_func, out_dtype, a_arr, y_arr)

    class stats:

        class invgamma:

            def ppf(q, a):
                """
                Quantile function of the inverse-gamma distribution with shape
                `a` and unit scale, computed as the reciprocal of the inverse
                regularized upper incomplete gamma function.
                """
                return 1 / scipy.special.gammainccinv(a, q)
|
|
87
|
+
|
|
88
|
+
@functools.wraps(jax.vmap)
def vmap_nodoc(fun, *args, **kw):
    """
    Version of `jax.vmap` that preserves the docstring of the input function.
    """
    # The decorator above copies jax.vmap's metadata (including its
    # docstring) onto this function, so the docstring written here is
    # replaced at runtime.
    saved_doc = fun.__doc__
    batched = jax.vmap(fun, *args, **kw)
    batched.__doc__ = saved_doc
    return batched
|
|
97
|
+
|
|
98
|
+
def huge_value(x):
    """
    Return the maximum value that can be stored in `x`.

    Parameters
    ----------
    x : array
        A numerical numpy or jax array.

    Returns
    -------
    maxval : scalar
        The maximum value allowed by `x`'s type (+inf for floats).
    """
    dtype = x.dtype
    if not jnp.issubdtype(dtype, jnp.integer):
        # Every non-integer dtype (floats, but also e.g. bool) maps to +inf.
        return jnp.inf
    return jnp.iinfo(dtype).max
|
|
116
|
+
|
|
117
|
+
def minimal_unsigned_dtype(max_value):
    """
    Return the smallest unsigned integer dtype that can represent a given
    maximum value (inclusive).

    Candidates are tried from narrowest to widest; `jnp.uint64` is the
    fallback.
    """
    candidates = ((8, jnp.uint8), (16, jnp.uint16), (32, jnp.uint32))
    for nbits, dtype in candidates:
        if max_value < 2 ** nbits:
            return dtype
    return jnp.uint64
|
|
129
|
+
|
|
130
|
+
def signed_to_unsigned(int_dtype):
    """
    Map a signed integer type to its unsigned counterpart. Unsigned types are
    passed through.

    Signed 8, 16, 32 and 64 bit types map to `jnp.uint8` ... `jnp.uint64`.
    Any other signed integer type yields None.
    """
    assert jnp.issubdtype(int_dtype, jnp.integer)
    if jnp.issubdtype(int_dtype, jnp.unsignedinteger):
        return int_dtype
    pairs = (
        (jnp.int8, jnp.uint8),
        (jnp.int16, jnp.uint16),
        (jnp.int32, jnp.uint32),
        (jnp.int64, jnp.uint64),
    )
    for signed, unsigned in pairs:
        if int_dtype == signed:
            return unsigned
    return None
|
|
146
|
+
|
|
147
|
+
def ensure_unsigned(x):
    """
    Cast `x` to the unsigned integer dtype of the same width if it has a
    signed integer dtype; arrays that are already unsigned keep their dtype.
    """
    target = signed_to_unsigned(x.dtype)
    return x.astype(target)
|
|
152
|
+
|
|
153
|
+
@functools.partial(jax.jit, static_argnums=(1,))
def unique(x, size, fill_value):
    """
    Restricted version of `jax.numpy.unique` that uses less memory.

    Parameters
    ----------
    x : 1d array
        The input array.
    size : int
        The length of the output.
    fill_value : scalar
        The value to fill the output with if `size` is greater than the number
        of unique values in `x`.

    Returns
    -------
    out : array (size,)
        The unique values in `x`, sorted, and right-padded with `fill_value`.
        If `x` has more than `size` unique values, only the `size` smallest
        are kept.
    actual_length : int
        The number of used values in `out` (at most `size`).
    """
    if x.size == 0:
        return jnp.full(size, fill_value, x.dtype), 0
    if size == 0:
        return jnp.empty(0, x.dtype), 0
    x = jnp.sort(x)

    def loop(carry, value):
        i_out, last, out = carry
        # advance the output position only when a new distinct value starts
        i_out = jnp.where(value == last, i_out, i_out + 1)
        # positions >= size are beyond the requested output: drop those writes
        out = out.at[i_out].set(value, mode='drop')
        return (i_out, value, out), None

    carry = 0, x[0], jnp.full(size, fill_value, x.dtype)
    # Scan the whole sorted input. Scanning only the first `size` elements
    # would miss distinct values that follow repeated ones.
    (i_last, _, out), _ = jax.lax.scan(loop, carry, x)
    return out, jnp.minimum(i_last + 1, size)
|
|
188
|
+
|
|
189
|
+
def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False):
    """
    Batch a function such that each batch is smaller than a threshold.

    Parameters
    ----------
    func : callable
        A jittable function with positional arguments only, with inputs and
        outputs pytrees of arrays.
    max_io_nbytes : int
        The maximum number of input + output bytes in each batch.
    in_axes : pytree of ints, default 0
        A tree matching the structure of the function input, indicating along
        which axes each array should be batched. If a single integer, it is
        used for all arrays.
    out_axes : pytree of ints, default 0
        The same for outputs.
    return_nbatches : bool, default False
        If True, the number of batches is returned as a second output.

    Returns
    -------
    batched_func : callable
        A function with the same signature as `func`, but that processes the
        input and output in batches in a loop.

    Notes
    -----
    The number of batches is the smallest divisor of the batched axis length
    that is at least ``ceil(total_io_nbytes / max_io_nbytes)`` (and at least
    1). If even one element per batch exceeds `max_io_nbytes`, a warning is
    emitted.
    """

    def expand_axes(axes, tree):
        # broadcast a single int to every leaf, or align an axes tree to `tree`
        if isinstance(axes, int):
            return tree_util.tree_map(lambda _: axes, tree)
        return tree_util.tree_map(lambda _, axis: axis, tree, axes)

    def extract_size(axes, tree):
        # common length of the batched axis across all leaves
        sizes = tree_util.tree_map(lambda x, axis: x.shape[axis], tree, axes)
        sizes, _ = tree_util.tree_flatten(sizes)
        assert all(s == sizes[0] for s in sizes)
        return sizes[0]

    def sum_nbytes(tree):
        def nbytes(x):
            return math.prod(x.shape) * x.dtype.itemsize
        return tree_util.tree_reduce(lambda size, x: size + nbytes(x), tree, 0)

    def next_divisor_small(dividend, min_divisor):
        # smallest divisor of `dividend` that is >= min_divisor, assuming
        # 1 <= min_divisor <= isqrt(dividend)
        root = math.isqrt(dividend)
        for divisor in range(min_divisor, root + 1):
            if dividend % divisor == 0:
                return divisor
        # No divisor lies in [min_divisor, root], so the answer is > root.
        # Its cofactor is then < min_divisor, and the largest dividing
        # cofactor yields the smallest such divisor.
        for cofactor in range(min_divisor - 1, 0, -1):
            if dividend % cofactor == 0:
                return dividend // cofactor
        return dividend

    def next_divisor_large(dividend, min_divisor):
        # same as above when min_divisor > sqrt(dividend): scan cofactors,
        # which are all <= dividend // min_divisor
        max_inv_divisor = dividend // min_divisor
        for inv_divisor in range(max_inv_divisor, 0, -1):
            if dividend % inv_divisor == 0:
                return dividend // inv_divisor
        return dividend

    def next_divisor(dividend, min_divisor):
        if min_divisor * min_divisor <= dividend:
            return next_divisor_small(dividend, min_divisor)
        return next_divisor_large(dividend, min_divisor)

    def move_axes_out(axes, tree):
        def move_axis_out(axis, x):
            if axis != 0:
                return jnp.moveaxis(x, axis, 0)
            return x
        return tree_util.tree_map(move_axis_out, axes, tree)

    def move_axes_in(axes, tree):
        def move_axis_in(axis, x):
            if axis != 0:
                return jnp.moveaxis(x, 0, axis)
            return x
        return tree_util.tree_map(move_axis_in, axes, tree)

    def batch(tree, nbatches):
        def batch(x):
            return x.reshape((nbatches, x.shape[0] // nbatches) + x.shape[1:])
        return tree_util.tree_map(batch, tree)

    def unbatch(tree):
        def unbatch(x):
            return x.reshape((x.shape[0] * x.shape[1],) + x.shape[2:])
        return tree_util.tree_map(unbatch, tree)

    def check_same(tree1, tree2):
        def check_same(x1, x2):
            assert x1.shape == x2.shape
            assert x1.dtype == x2.dtype
        tree_util.tree_map(check_same, tree1, tree2)

    initial_in_axes = in_axes
    initial_out_axes = out_axes

    @jax.jit
    @functools.wraps(func)
    def batched_func(*args):
        example_result = jax.eval_shape(func, *args)

        in_axes = expand_axes(initial_in_axes, args)
        out_axes = expand_axes(initial_out_axes, example_result)

        in_size = extract_size(in_axes, args)
        out_size = extract_size(out_axes, example_result)
        assert in_size == out_size
        size = in_size

        total_nbytes = sum_nbytes(args) + sum_nbytes(example_result)
        # ceil(total / max), but at least 1: with zero bytes of I/O, a zero
        # minimum would otherwise lead to a modulo-by-zero in next_divisor
        min_nbatches = max(1, total_nbytes // max_io_nbytes + bool(total_nbytes % max_io_nbytes))
        if size == 0:
            nbatches = 1  # nothing to split: one (empty) batch
        else:
            nbatches = next_divisor(size, min_nbatches)
        assert 1 <= nbatches <= max(size, 1)
        assert size % nbatches == 0
        assert total_nbytes % nbatches == 0

        batch_nbytes = total_nbytes // nbatches
        if batch_nbytes > max_io_nbytes:
            assert size == nbatches
            warnings.warn(f'batch_nbytes = {batch_nbytes} > max_io_nbytes = {max_io_nbytes}')

        def loop(_, args):
            args = move_axes_in(in_axes, args)
            result = func(*args)
            result = move_axes_out(out_axes, result)
            return None, result

        args = move_axes_out(in_axes, args)
        args = batch(args, nbatches)
        _, result = lax.scan(loop, None, args)
        result = unbatch(result)
        result = move_axes_in(out_axes, result)

        check_same(example_result, result)

        if return_nbatches:
            return result, nbatches
        return result

    return batched_func
|
|
327
|
+
|
|
328
|
+
@tree_util.register_pytree_node_class
class LeafDict(dict):
    """
    Dictionary that acts as a leaf in jax pytrees, to store compile-time
    values.

    When flattened it exposes no children: the whole dictionary travels as
    the auxiliary data of the tree structure, so its contents are not traced.
    """

    def tree_flatten(self):
        # no children; the dict itself is the auxiliary (static) data
        return (), self

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return aux_data

    def __repr__(self):
        # type(self) rather than the `__class__` cell, so subclasses show
        # their own name
        return f'{type(self).__name__}({super().__repr__()})'
|
|
@@ -52,7 +52,7 @@ def run_mcmc(bart, n_burn, n_save, n_skip, callback, key):
|
|
|
52
52
|
n_save : int
|
|
53
53
|
The number of iterations to save.
|
|
54
54
|
n_skip : int
|
|
55
|
-
The number of iterations to skip between each saved iteration.
|
|
55
|
+
The number of iterations to skip between each saved iteration, plus 1.
|
|
56
56
|
callback : callable
|
|
57
57
|
An arbitrary function run at each iteration, called with the following
|
|
58
58
|
arguments, passed by keyword:
|
|
@@ -105,16 +105,19 @@ def run_mcmc(bart, n_burn, n_save, n_skip, callback, key):
|
|
|
105
105
|
output = {key: bart[key] for key in tracelist}
|
|
106
106
|
return (bart, i_total + 1, i_skip + 1, key), output
|
|
107
107
|
|
|
108
|
+
def empty_trace(bart, tracelist):
    """
    Build a trace with zero recorded iterations.

    For every key in `tracelist`, in order, the result holds an empty array
    whose shape is ``(0,)`` followed by the shape of `bart[key]`, with the
    same dtype.
    """
    trace = {}
    for name in tracelist:
        template = bart[name]
        trace[name] = jnp.empty((0, *template.shape), template.dtype)
    return trace
|
|
113
|
+
|
|
108
114
|
if n_burn > 0:
|
|
109
115
|
carry = bart, 0, 0, key
|
|
110
116
|
burnin_loop = functools.partial(inner_loop, tracelist=tracelist_burnin, burnin=True)
|
|
111
117
|
(bart, i_total, _, key), burnin_trace = lax.scan(burnin_loop, carry, None, n_burn)
|
|
112
118
|
else:
|
|
113
119
|
i_total = 0
|
|
114
|
-
burnin_trace =
|
|
115
|
-
key: jnp.empty((0,) + bart[key].shape, bart[key].dtype)
|
|
116
|
-
for key in tracelist_burnin
|
|
117
|
-
}
|
|
120
|
+
burnin_trace = empty_trace(bart, tracelist_burnin)
|
|
118
121
|
|
|
119
122
|
def outer_loop(carry, _):
|
|
120
123
|
bart, i_total, key = carry
|
|
@@ -124,8 +127,11 @@ def run_mcmc(bart, n_burn, n_save, n_skip, callback, key):
|
|
|
124
127
|
output = {key: bart[key] for key in tracelist_main}
|
|
125
128
|
return (bart, i_total, key), output
|
|
126
129
|
|
|
127
|
-
|
|
128
|
-
|
|
130
|
+
if n_save > 0:
|
|
131
|
+
carry = bart, i_total, key
|
|
132
|
+
(bart, _, _), main_trace = lax.scan(outer_loop, carry, None, n_save)
|
|
133
|
+
else:
|
|
134
|
+
main_trace = empty_trace(bart, tracelist_main)
|
|
129
135
|
|
|
130
136
|
return bart, burnin_trace, main_trace
|
|
131
137
|
|
|
@@ -133,7 +139,8 @@ def run_mcmc(bart, n_burn, n_save, n_skip, callback, key):
|
|
|
133
139
|
|
|
134
140
|
@functools.lru_cache
|
|
135
141
|
# cache to make the callback function object unique, such that the jit
|
|
136
|
-
# of run_mcmc recognizes it
|
|
142
|
+
# of run_mcmc recognizes it => with the callback state, I can make
|
|
143
|
+
# printevery a runtime quantity
|
|
137
144
|
def make_simple_print_callback(printevery):
|
|
138
145
|
"""
|
|
139
146
|
Create a logging callback function for MCMC iterations.
|
|
@@ -155,11 +162,12 @@ def make_simple_print_callback(printevery):
|
|
|
155
162
|
grow_acc = bart['grow_acc_count'] / bart['grow_prop_count']
|
|
156
163
|
prune_acc = bart['prune_acc_count'] / bart['prune_prop_count']
|
|
157
164
|
n_total = n_burn + n_save * n_skip
|
|
158
|
-
|
|
165
|
+
printcond = (i_total + 1) % printevery == 0
|
|
166
|
+
debug.callback(_simple_print_callback, burnin, i_total, n_total, grow_prop, grow_acc, prune_prop, prune_acc, printcond)
|
|
159
167
|
return callback
|
|
160
168
|
|
|
161
|
-
def
|
|
162
|
-
if
|
|
169
|
+
def _simple_print_callback(burnin, i_total, n_total, grow_prop, grow_acc, prune_prop, prune_acc, printcond):
|
|
170
|
+
if printcond:
|
|
163
171
|
burnin_flag = ' (burnin)' if burnin else ''
|
|
164
172
|
total_str = str(n_total)
|
|
165
173
|
ndigits = len(total_str)
|