bartz 0.1.0__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bartz/{interface.py → BART.py} +4 -3
- bartz/__init__.py +6 -1
- bartz/_version.py +1 -1
- bartz/debug.py +5 -4
- bartz/grove.py +36 -36
- bartz/jaxext.py +261 -5
- bartz/mcmcloop.py +19 -11
- bartz/mcmcstep.py +200 -73
- bartz/prepcovars.py +25 -30
- {bartz-0.1.0.dist-info → bartz-0.2.1.dist-info}/METADATA +7 -1
- bartz-0.2.1.dist-info/RECORD +13 -0
- bartz-0.1.0.dist-info/RECORD +0 -13
- {bartz-0.1.0.dist-info → bartz-0.2.1.dist-info}/LICENSE +0 -0
- {bartz-0.1.0.dist-info → bartz-0.2.1.dist-info}/WHEEL +0 -0
bartz/{interface.py → BART.py}
RENAMED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# bartz/src/bartz/
|
|
1
|
+
# bartz/src/bartz/BART.py
|
|
2
2
|
#
|
|
3
3
|
# Copyright (c) 2024, Giacomo Petrillo
|
|
4
4
|
#
|
|
@@ -33,7 +33,7 @@ from . import mcmcstep
|
|
|
33
33
|
from . import mcmcloop
|
|
34
34
|
from . import prepcovars
|
|
35
35
|
|
|
36
|
-
class
|
|
36
|
+
class gbart:
|
|
37
37
|
"""
|
|
38
38
|
Nonparametric regression with Bayesian Additive Regression Trees (BART).
|
|
39
39
|
|
|
@@ -133,7 +133,7 @@ class BART:
|
|
|
133
133
|
|
|
134
134
|
Notes
|
|
135
135
|
-----
|
|
136
|
-
This interface imitates the function `
|
|
136
|
+
This interface imitates the function `gbart` from the R package `BART
|
|
137
137
|
<https://cran.r-project.org/package=BART>`_, but with these differences:
|
|
138
138
|
|
|
139
139
|
- If `x_train` and `x_test` are matrices, they have one predictor per row
|
|
@@ -142,6 +142,7 @@ class BART:
|
|
|
142
142
|
- `usequants` is always `True`.
|
|
143
143
|
- `rm_const` is always `False`.
|
|
144
144
|
- The default `numcut` is 255 instead of 100.
|
|
145
|
+
- A lot of functionality is missing (variable selection, discrete response).
|
|
145
146
|
- There are some additional attributes, and some missing.
|
|
146
147
|
"""
|
|
147
148
|
|
bartz/__init__.py
CHANGED
|
@@ -30,6 +30,11 @@ See the manual at https://gattocrucco.github.io/bartz/docs
|
|
|
30
30
|
|
|
31
31
|
from ._version import __version__
|
|
32
32
|
|
|
33
|
-
from .
|
|
33
|
+
from . import BART
|
|
34
34
|
|
|
35
35
|
from . import debug
|
|
36
|
+
from . import grove
|
|
37
|
+
from . import mcmcstep
|
|
38
|
+
from . import mcmcloop
|
|
39
|
+
from . import prepcovars
|
|
40
|
+
from . import jaxext
|
bartz/_version.py
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
__version__ = '0.1
|
|
1
|
+
__version__ = '0.2.1'
|
bartz/debug.py
CHANGED
|
@@ -6,6 +6,7 @@ from jax import lax
|
|
|
6
6
|
|
|
7
7
|
from . import grove
|
|
8
8
|
from . import mcmcstep
|
|
9
|
+
from . import jaxext
|
|
9
10
|
|
|
10
11
|
def print_tree(leaf_tree, var_tree, split_tree, print_all=False):
|
|
11
12
|
|
|
@@ -83,7 +84,7 @@ def trace_depth_distr(split_trees_trace):
|
|
|
83
84
|
def points_per_leaf_distr(var_tree, split_tree, X):
|
|
84
85
|
traverse_tree = jax.vmap(grove.traverse_tree, in_axes=(1, None, None))
|
|
85
86
|
indices = traverse_tree(X, var_tree, split_tree)
|
|
86
|
-
count_tree = jnp.zeros(2 * split_tree.size, dtype=
|
|
87
|
+
count_tree = jnp.zeros(2 * split_tree.size, dtype=jaxext.minimal_unsigned_dtype(indices.size))
|
|
87
88
|
count_tree = count_tree.at[indices].add(1)
|
|
88
89
|
is_leaf = grove.is_actual_leaf(split_tree, add_bottom_level=True).view(jnp.uint8)
|
|
89
90
|
return jnp.bincount(count_tree, is_leaf, length=X.shape[1] + 1)
|
|
@@ -103,7 +104,7 @@ def trace_points_per_leaf_distr(bart, X):
|
|
|
103
104
|
return distr
|
|
104
105
|
|
|
105
106
|
def check_types(leaf_tree, var_tree, split_tree, max_split):
|
|
106
|
-
expected_var_dtype =
|
|
107
|
+
expected_var_dtype = jaxext.minimal_unsigned_dtype(max_split.size - 1)
|
|
107
108
|
expected_split_dtype = max_split.dtype
|
|
108
109
|
return var_tree.dtype == expected_var_dtype and split_tree.dtype == expected_split_dtype
|
|
109
110
|
|
|
@@ -117,7 +118,7 @@ def check_leaf_values(leaf_tree, var_tree, split_tree, max_split):
|
|
|
117
118
|
return jnp.all(jnp.isfinite(leaf_tree))
|
|
118
119
|
|
|
119
120
|
def check_stray_nodes(leaf_tree, var_tree, split_tree, max_split):
|
|
120
|
-
index = jnp.arange(2 * split_tree.size, dtype=
|
|
121
|
+
index = jnp.arange(2 * split_tree.size, dtype=jaxext.minimal_unsigned_dtype(2 * split_tree.size - 1))
|
|
121
122
|
parent_index = index >> 1
|
|
122
123
|
is_not_leaf = split_tree.at[index].get(mode='fill', fill_value=0) != 0
|
|
123
124
|
parent_is_leaf = split_tree[parent_index] == 0
|
|
@@ -134,7 +135,7 @@ check_functions = [
|
|
|
134
135
|
]
|
|
135
136
|
|
|
136
137
|
def check_tree(leaf_tree, var_tree, split_tree, max_split):
|
|
137
|
-
error_type =
|
|
138
|
+
error_type = jaxext.minimal_unsigned_dtype(2 ** len(check_functions) - 1)
|
|
138
139
|
error = error_type(0)
|
|
139
140
|
for i, func in enumerate(check_functions):
|
|
140
141
|
ok = func(leaf_tree, var_tree, split_tree, max_split)
|
bartz/grove.py
CHANGED
|
@@ -44,7 +44,6 @@ import functools
|
|
|
44
44
|
import math
|
|
45
45
|
|
|
46
46
|
import jax
|
|
47
|
-
|
|
48
47
|
from jax import numpy as jnp
|
|
49
48
|
from jax import lax
|
|
50
49
|
|
|
@@ -107,29 +106,47 @@ def traverse_tree(x, var_tree, split_tree):
|
|
|
107
106
|
|
|
108
107
|
carry = (
|
|
109
108
|
jnp.zeros((), bool),
|
|
110
|
-
jnp.ones((), minimal_unsigned_dtype(2 * var_tree.size - 1)),
|
|
109
|
+
jnp.ones((), jaxext.minimal_unsigned_dtype(2 * var_tree.size - 1)),
|
|
111
110
|
)
|
|
112
111
|
|
|
113
112
|
def loop(carry, _):
|
|
114
113
|
leaf_found, index = carry
|
|
115
114
|
|
|
116
|
-
split = split_tree
|
|
117
|
-
var = var_tree
|
|
115
|
+
split = split_tree[index]
|
|
116
|
+
var = var_tree[index]
|
|
118
117
|
|
|
119
|
-
leaf_found |=
|
|
118
|
+
leaf_found |= split == 0
|
|
120
119
|
child_index = (index << 1) + (x[var] >= split)
|
|
121
120
|
index = jnp.where(leaf_found, index, child_index)
|
|
122
121
|
|
|
123
122
|
return (leaf_found, index), None
|
|
124
123
|
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
# - separate and special-case the last iteration
|
|
128
|
-
|
|
129
|
-
depth = 1 + tree_depth(var_tree)
|
|
130
|
-
(_, index), _ = lax.scan(loop, carry, None, depth)
|
|
124
|
+
depth = tree_depth(var_tree)
|
|
125
|
+
(_, index), _ = lax.scan(loop, carry, None, depth, unroll=16)
|
|
131
126
|
return index
|
|
132
127
|
|
|
128
|
+
@functools.partial(jaxext.vmap_nodoc, in_axes=(None, 0, 0))
|
|
129
|
+
@functools.partial(jaxext.vmap_nodoc, in_axes=(1, None, None))
|
|
130
|
+
def traverse_forest(X, var_trees, split_trees):
|
|
131
|
+
"""
|
|
132
|
+
Find the leaves where points fall into.
|
|
133
|
+
|
|
134
|
+
Parameters
|
|
135
|
+
----------
|
|
136
|
+
X : array (p, n)
|
|
137
|
+
The coordinates to evaluate the trees at.
|
|
138
|
+
var_trees : array (m, 2 ** (d - 1))
|
|
139
|
+
The decision axes of the trees.
|
|
140
|
+
split_trees : array (m, 2 ** (d - 1))
|
|
141
|
+
The decision boundaries of the trees.
|
|
142
|
+
|
|
143
|
+
Returns
|
|
144
|
+
-------
|
|
145
|
+
indices : array (m, n)
|
|
146
|
+
The indices of the leaves.
|
|
147
|
+
"""
|
|
148
|
+
return traverse_tree(X, var_trees, split_trees)
|
|
149
|
+
|
|
133
150
|
def evaluate_forest(X, leaf_trees, var_trees, split_trees, dtype):
|
|
134
151
|
"""
|
|
135
152
|
Evaluate a ensemble of trees at an array of points.
|
|
@@ -138,7 +155,7 @@ def evaluate_forest(X, leaf_trees, var_trees, split_trees, dtype):
|
|
|
138
155
|
----------
|
|
139
156
|
X : array (p, n)
|
|
140
157
|
The coordinates to evaluate the trees at.
|
|
141
|
-
leaf_trees : (m, 2 ** d)
|
|
158
|
+
leaf_trees : array (m, 2 ** d)
|
|
142
159
|
The leaf values of the tree or forest. If the input is a forest, the
|
|
143
160
|
first axis is the tree index, and the values are summed.
|
|
144
161
|
var_trees : array (m, 2 ** (d - 1))
|
|
@@ -153,30 +170,13 @@ def evaluate_forest(X, leaf_trees, var_trees, split_trees, dtype):
|
|
|
153
170
|
out : array (n,)
|
|
154
171
|
The sum of the values of the trees at the points in `X`.
|
|
155
172
|
"""
|
|
156
|
-
indices =
|
|
173
|
+
indices = traverse_forest(X, var_trees, split_trees)
|
|
157
174
|
ntree, _ = leaf_trees.shape
|
|
158
|
-
tree_index = jnp.arange(ntree, dtype=minimal_unsigned_dtype(ntree - 1))[:, None]
|
|
175
|
+
tree_index = jnp.arange(ntree, dtype=jaxext.minimal_unsigned_dtype(ntree - 1))[:, None]
|
|
159
176
|
leaves = leaf_trees[tree_index, indices]
|
|
160
177
|
return jnp.sum(leaves, axis=0, dtype=dtype)
|
|
161
|
-
# this sum suggests to swap the vmaps, but I think it's better for X
|
|
162
|
-
|
|
163
|
-
@functools.partial(jax.vmap, in_axes=(None, 0, 0))
|
|
164
|
-
@functools.partial(jax.vmap, in_axes=(1, None, None))
|
|
165
|
-
def _traverse_forest(X, var_trees, split_trees):
|
|
166
|
-
return traverse_tree(X, var_trees, split_trees)
|
|
167
|
-
|
|
168
|
-
def minimal_unsigned_dtype(max_value):
|
|
169
|
-
"""
|
|
170
|
-
Return the smallest unsigned integer dtype that can represent a given
|
|
171
|
-
maximum value.
|
|
172
|
-
"""
|
|
173
|
-
if max_value < 2 ** 8:
|
|
174
|
-
return jnp.uint8
|
|
175
|
-
if max_value < 2 ** 16:
|
|
176
|
-
return jnp.uint16
|
|
177
|
-
if max_value < 2 ** 32:
|
|
178
|
-
return jnp.uint32
|
|
179
|
-
return jnp.uint64
|
|
178
|
+
# this sum suggests to swap the vmaps, but I think it's better for X
|
|
179
|
+
# copying to keep it that way
|
|
180
180
|
|
|
181
181
|
def is_actual_leaf(split_tree, *, add_bottom_level=False):
|
|
182
182
|
"""
|
|
@@ -200,7 +200,7 @@ def is_actual_leaf(split_tree, *, add_bottom_level=False):
|
|
|
200
200
|
if add_bottom_level:
|
|
201
201
|
size *= 2
|
|
202
202
|
is_leaf = jnp.concatenate([is_leaf, jnp.ones_like(is_leaf)])
|
|
203
|
-
index = jnp.arange(size, dtype=minimal_unsigned_dtype(size - 1))
|
|
203
|
+
index = jnp.arange(size, dtype=jaxext.minimal_unsigned_dtype(size - 1))
|
|
204
204
|
parent_index = index >> 1
|
|
205
205
|
parent_nonleaf = split_tree[parent_index].astype(bool)
|
|
206
206
|
parent_nonleaf = parent_nonleaf.at[1].set(True)
|
|
@@ -220,7 +220,7 @@ def is_leaves_parent(split_tree):
|
|
|
220
220
|
is_leaves_parent : bool array (2 ** (d - 1),)
|
|
221
221
|
The mask indicating which nodes have leaf children.
|
|
222
222
|
"""
|
|
223
|
-
index = jnp.arange(split_tree.size, dtype=minimal_unsigned_dtype(2 * split_tree.size - 1))
|
|
223
|
+
index = jnp.arange(split_tree.size, dtype=jaxext.minimal_unsigned_dtype(2 * split_tree.size - 1))
|
|
224
224
|
left_index = index << 1 # left child
|
|
225
225
|
right_index = left_index + 1 # right child
|
|
226
226
|
left_leaf = split_tree.at[left_index].get(mode='fill', fill_value=0) == 0
|
|
@@ -252,4 +252,4 @@ def tree_depths(tree_length):
|
|
|
252
252
|
depth += 1
|
|
253
253
|
depths.append(depth - 1)
|
|
254
254
|
depths[0] = 0
|
|
255
|
-
return jnp.array(depths, minimal_unsigned_dtype(max(depths)))
|
|
255
|
+
return jnp.array(depths, jaxext.minimal_unsigned_dtype(max(depths)))
|
bartz/jaxext.py
CHANGED
|
@@ -10,10 +10,10 @@
|
|
|
10
10
|
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
11
11
|
# copies of the Software, and to permit persons to whom the Software is
|
|
12
12
|
# furnished to do so, subject to the following conditions:
|
|
13
|
-
#
|
|
13
|
+
#
|
|
14
14
|
# The above copyright notice and this permission notice shall be included in all
|
|
15
15
|
# copies or substantial portions of the Software.
|
|
16
|
-
#
|
|
16
|
+
#
|
|
17
17
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
18
18
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
19
19
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
@@ -23,12 +23,19 @@
|
|
|
23
23
|
# SOFTWARE.
|
|
24
24
|
|
|
25
25
|
import functools
|
|
26
|
+
import math
|
|
27
|
+
import warnings
|
|
26
28
|
|
|
27
29
|
from scipy import special
|
|
28
30
|
import jax
|
|
29
31
|
from jax import numpy as jnp
|
|
32
|
+
from jax import tree_util
|
|
33
|
+
from jax import lax
|
|
30
34
|
|
|
31
35
|
def float_type(*args):
|
|
36
|
+
"""
|
|
37
|
+
Determine the jax floating point result type given operands/types.
|
|
38
|
+
"""
|
|
32
39
|
t = jnp.result_type(*args)
|
|
33
40
|
return jnp.sin(jnp.empty(0, t)).dtype
|
|
34
41
|
|
|
@@ -39,8 +46,8 @@ def castto(func, type):
|
|
|
39
46
|
return newfunc
|
|
40
47
|
|
|
41
48
|
def pure_callback_ufunc(callback, dtype, *args, excluded=None, **kwargs):
|
|
42
|
-
""" version of jax.pure_callback that deals correctly with ufuncs,
|
|
43
|
-
see https://github.com/google/jax/issues/17187 """
|
|
49
|
+
""" version of `jax.pure_callback` that deals correctly with ufuncs,
|
|
50
|
+
see `<https://github.com/google/jax/issues/17187>`_ """
|
|
44
51
|
if excluded is None:
|
|
45
52
|
excluded = ()
|
|
46
53
|
shape = jnp.broadcast_shapes(*(
|
|
@@ -63,6 +70,7 @@ class scipy:
|
|
|
63
70
|
|
|
64
71
|
class special:
|
|
65
72
|
|
|
73
|
+
@functools.wraps(special.gammainccinv)
|
|
66
74
|
def gammainccinv(a, y):
|
|
67
75
|
a = jnp.asarray(a)
|
|
68
76
|
y = jnp.asarray(y)
|
|
@@ -73,13 +81,261 @@ class scipy:
|
|
|
73
81
|
class stats:
|
|
74
82
|
|
|
75
83
|
class invgamma:
|
|
76
|
-
|
|
84
|
+
|
|
77
85
|
def ppf(q, a):
|
|
78
86
|
return 1 / scipy.special.gammainccinv(a, q)
|
|
79
87
|
|
|
80
88
|
@functools.wraps(jax.vmap)
|
|
81
89
|
def vmap_nodoc(fun, *args, **kw):
|
|
90
|
+
"""
|
|
91
|
+
Version of `jax.vmap` that preserves the docstring of the input function.
|
|
92
|
+
"""
|
|
82
93
|
doc = fun.__doc__
|
|
83
94
|
fun = jax.vmap(fun, *args, **kw)
|
|
84
95
|
fun.__doc__ = doc
|
|
85
96
|
return fun
|
|
97
|
+
|
|
98
|
+
def huge_value(x):
|
|
99
|
+
"""
|
|
100
|
+
Return the maximum value that can be stored in `x`.
|
|
101
|
+
|
|
102
|
+
Parameters
|
|
103
|
+
----------
|
|
104
|
+
x : array
|
|
105
|
+
A numerical numpy or jax array.
|
|
106
|
+
|
|
107
|
+
Returns
|
|
108
|
+
-------
|
|
109
|
+
maxval : scalar
|
|
110
|
+
The maximum value allowed by `x`'s type (+inf for floats).
|
|
111
|
+
"""
|
|
112
|
+
if jnp.issubdtype(x.dtype, jnp.integer):
|
|
113
|
+
return jnp.iinfo(x.dtype).max
|
|
114
|
+
else:
|
|
115
|
+
return jnp.inf
|
|
116
|
+
|
|
117
|
+
def minimal_unsigned_dtype(max_value):
|
|
118
|
+
"""
|
|
119
|
+
Return the smallest unsigned integer dtype that can represent a given
|
|
120
|
+
maximum value (inclusive).
|
|
121
|
+
"""
|
|
122
|
+
if max_value < 2 ** 8:
|
|
123
|
+
return jnp.uint8
|
|
124
|
+
if max_value < 2 ** 16:
|
|
125
|
+
return jnp.uint16
|
|
126
|
+
if max_value < 2 ** 32:
|
|
127
|
+
return jnp.uint32
|
|
128
|
+
return jnp.uint64
|
|
129
|
+
|
|
130
|
+
def signed_to_unsigned(int_dtype):
|
|
131
|
+
"""
|
|
132
|
+
Map a signed integer type to its unsigned counterpart. Unsigned types are
|
|
133
|
+
passed through.
|
|
134
|
+
"""
|
|
135
|
+
assert jnp.issubdtype(int_dtype, jnp.integer)
|
|
136
|
+
if jnp.issubdtype(int_dtype, jnp.unsignedinteger):
|
|
137
|
+
return int_dtype
|
|
138
|
+
if int_dtype == jnp.int8:
|
|
139
|
+
return jnp.uint8
|
|
140
|
+
if int_dtype == jnp.int16:
|
|
141
|
+
return jnp.uint16
|
|
142
|
+
if int_dtype == jnp.int32:
|
|
143
|
+
return jnp.uint32
|
|
144
|
+
if int_dtype == jnp.int64:
|
|
145
|
+
return jnp.uint64
|
|
146
|
+
|
|
147
|
+
def ensure_unsigned(x):
|
|
148
|
+
"""
|
|
149
|
+
If x has signed integer type, cast it to the unsigned dtype of the same size.
|
|
150
|
+
"""
|
|
151
|
+
return x.astype(signed_to_unsigned(x.dtype))
|
|
152
|
+
|
|
153
|
+
@functools.partial(jax.jit, static_argnums=(1,))
|
|
154
|
+
def unique(x, size, fill_value):
|
|
155
|
+
"""
|
|
156
|
+
Restricted version of `jax.numpy.unique` that uses less memory.
|
|
157
|
+
|
|
158
|
+
Parameters
|
|
159
|
+
----------
|
|
160
|
+
x : 1d array
|
|
161
|
+
The input array.
|
|
162
|
+
size : int
|
|
163
|
+
The length of the output.
|
|
164
|
+
fill_value : scalar
|
|
165
|
+
The value to fill the output with if `size` is greater than the number
|
|
166
|
+
of unique values in `x`.
|
|
167
|
+
|
|
168
|
+
Returns
|
|
169
|
+
-------
|
|
170
|
+
out : array (size,)
|
|
171
|
+
The unique values in `x`, sorted, and right-padded with `fill_value`.
|
|
172
|
+
actual_length : int
|
|
173
|
+
The number of used values in `out`.
|
|
174
|
+
"""
|
|
175
|
+
if x.size == 0:
|
|
176
|
+
return jnp.full(size, fill_value, x.dtype), 0
|
|
177
|
+
if size == 0:
|
|
178
|
+
return jnp.empty(0, x.dtype), 0
|
|
179
|
+
x = jnp.sort(x)
|
|
180
|
+
def loop(carry, x):
|
|
181
|
+
i_out, i_in, last, out = carry
|
|
182
|
+
i_out = jnp.where(x == last, i_out, i_out + 1)
|
|
183
|
+
out = out.at[i_out].set(x)
|
|
184
|
+
return (i_out, i_in + 1, x, out), None
|
|
185
|
+
carry = 0, 0, x[0], jnp.full(size, fill_value, x.dtype)
|
|
186
|
+
(actual_length, _, _, out), _ = jax.lax.scan(loop, carry, x[:size])
|
|
187
|
+
return out, actual_length + 1
|
|
188
|
+
|
|
189
|
+
def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False):
|
|
190
|
+
"""
|
|
191
|
+
Batch a function such that each batch is smaller than a threshold.
|
|
192
|
+
|
|
193
|
+
Parameters
|
|
194
|
+
----------
|
|
195
|
+
func : callable
|
|
196
|
+
A jittable function with positional arguments only, with inputs and
|
|
197
|
+
outputs pytrees of arrays.
|
|
198
|
+
max_io_nbytes : int
|
|
199
|
+
The maximum number of input + output bytes in each batch.
|
|
200
|
+
in_axes : pytree of ints, default 0
|
|
201
|
+
A tree matching the structure of the function input, indicating along
|
|
202
|
+
which axes each array should be batched. If a single integer, it is
|
|
203
|
+
used for all arrays.
|
|
204
|
+
out_axes : pytree of ints, default 0
|
|
205
|
+
The same for outputs.
|
|
206
|
+
return_nbatches : bool, default False
|
|
207
|
+
If True, the number of batches is returned as a second output.
|
|
208
|
+
|
|
209
|
+
Returns
|
|
210
|
+
-------
|
|
211
|
+
batched_func : callable
|
|
212
|
+
A function with the same signature as `func`, but that processes the
|
|
213
|
+
input and output in batches in a loop.
|
|
214
|
+
"""
|
|
215
|
+
|
|
216
|
+
def expand_axes(axes, tree):
|
|
217
|
+
if isinstance(axes, int):
|
|
218
|
+
return tree_util.tree_map(lambda _: axes, tree)
|
|
219
|
+
return tree_util.tree_map(lambda _, axis: axis, tree, axes)
|
|
220
|
+
|
|
221
|
+
def extract_size(axes, tree):
|
|
222
|
+
sizes = tree_util.tree_map(lambda x, axis: x.shape[axis], tree, axes)
|
|
223
|
+
sizes, _ = tree_util.tree_flatten(sizes)
|
|
224
|
+
assert all(s == sizes[0] for s in sizes)
|
|
225
|
+
return sizes[0]
|
|
226
|
+
|
|
227
|
+
def sum_nbytes(tree):
|
|
228
|
+
def nbytes(x):
|
|
229
|
+
return math.prod(x.shape) * x.dtype.itemsize
|
|
230
|
+
return tree_util.tree_reduce(lambda size, x: size + nbytes(x), tree, 0)
|
|
231
|
+
|
|
232
|
+
def next_divisor_small(dividend, min_divisor):
|
|
233
|
+
for divisor in range(min_divisor, int(math.sqrt(dividend)) + 1):
|
|
234
|
+
if dividend % divisor == 0:
|
|
235
|
+
return divisor
|
|
236
|
+
return dividend
|
|
237
|
+
|
|
238
|
+
def next_divisor_large(dividend, min_divisor):
|
|
239
|
+
max_inv_divisor = dividend // min_divisor
|
|
240
|
+
for inv_divisor in range(max_inv_divisor, 0, -1):
|
|
241
|
+
if dividend % inv_divisor == 0:
|
|
242
|
+
return dividend // inv_divisor
|
|
243
|
+
return dividend
|
|
244
|
+
|
|
245
|
+
def next_divisor(dividend, min_divisor):
|
|
246
|
+
if min_divisor * min_divisor <= dividend:
|
|
247
|
+
return next_divisor_small(dividend, min_divisor)
|
|
248
|
+
return next_divisor_large(dividend, min_divisor)
|
|
249
|
+
|
|
250
|
+
def move_axes_out(axes, tree):
|
|
251
|
+
def move_axis_out(axis, x):
|
|
252
|
+
if axis != 0:
|
|
253
|
+
return jnp.moveaxis(x, axis, 0)
|
|
254
|
+
return x
|
|
255
|
+
return tree_util.tree_map(move_axis_out, axes, tree)
|
|
256
|
+
|
|
257
|
+
def move_axes_in(axes, tree):
|
|
258
|
+
def move_axis_in(axis, x):
|
|
259
|
+
if axis != 0:
|
|
260
|
+
return jnp.moveaxis(x, 0, axis)
|
|
261
|
+
return x
|
|
262
|
+
return tree_util.tree_map(move_axis_in, axes, tree)
|
|
263
|
+
|
|
264
|
+
def batch(tree, nbatches):
|
|
265
|
+
def batch(x):
|
|
266
|
+
return x.reshape((nbatches, x.shape[0] // nbatches) + x.shape[1:])
|
|
267
|
+
return tree_util.tree_map(batch, tree)
|
|
268
|
+
|
|
269
|
+
def unbatch(tree):
|
|
270
|
+
def unbatch(x):
|
|
271
|
+
return x.reshape((x.shape[0] * x.shape[1],) + x.shape[2:])
|
|
272
|
+
return tree_util.tree_map(unbatch, tree)
|
|
273
|
+
|
|
274
|
+
def check_same(tree1, tree2):
|
|
275
|
+
def check_same(x1, x2):
|
|
276
|
+
assert x1.shape == x2.shape
|
|
277
|
+
assert x1.dtype == x2.dtype
|
|
278
|
+
tree_util.tree_map(check_same, tree1, tree2)
|
|
279
|
+
|
|
280
|
+
initial_in_axes = in_axes
|
|
281
|
+
initial_out_axes = out_axes
|
|
282
|
+
|
|
283
|
+
@jax.jit
|
|
284
|
+
@functools.wraps(func)
|
|
285
|
+
def batched_func(*args):
|
|
286
|
+
example_result = jax.eval_shape(func, *args)
|
|
287
|
+
|
|
288
|
+
in_axes = expand_axes(initial_in_axes, args)
|
|
289
|
+
out_axes = expand_axes(initial_out_axes, example_result)
|
|
290
|
+
|
|
291
|
+
in_size = extract_size(in_axes, args)
|
|
292
|
+
out_size = extract_size(out_axes, example_result)
|
|
293
|
+
assert in_size == out_size
|
|
294
|
+
size = in_size
|
|
295
|
+
|
|
296
|
+
total_nbytes = sum_nbytes(args) + sum_nbytes(example_result)
|
|
297
|
+
min_nbatches = total_nbytes // max_io_nbytes + bool(total_nbytes % max_io_nbytes)
|
|
298
|
+
nbatches = next_divisor(size, min_nbatches)
|
|
299
|
+
assert 1 <= nbatches <= size
|
|
300
|
+
assert size % nbatches == 0
|
|
301
|
+
assert total_nbytes % nbatches == 0
|
|
302
|
+
|
|
303
|
+
batch_nbytes = total_nbytes // nbatches
|
|
304
|
+
if batch_nbytes > max_io_nbytes:
|
|
305
|
+
assert size == nbatches
|
|
306
|
+
warnings.warn(f'batch_nbytes = {batch_nbytes} > max_io_nbytes = {max_io_nbytes}')
|
|
307
|
+
|
|
308
|
+
def loop(_, args):
|
|
309
|
+
args = move_axes_in(in_axes, args)
|
|
310
|
+
result = func(*args)
|
|
311
|
+
result = move_axes_out(out_axes, result)
|
|
312
|
+
return None, result
|
|
313
|
+
|
|
314
|
+
args = move_axes_out(in_axes, args)
|
|
315
|
+
args = batch(args, nbatches)
|
|
316
|
+
_, result = lax.scan(loop, None, args)
|
|
317
|
+
result = unbatch(result)
|
|
318
|
+
result = move_axes_in(out_axes, result)
|
|
319
|
+
|
|
320
|
+
check_same(example_result, result)
|
|
321
|
+
|
|
322
|
+
if return_nbatches:
|
|
323
|
+
return result, nbatches
|
|
324
|
+
return result
|
|
325
|
+
|
|
326
|
+
return batched_func
|
|
327
|
+
|
|
328
|
+
@tree_util.register_pytree_node_class
|
|
329
|
+
class LeafDict(dict):
|
|
330
|
+
""" dictionary that acts as a leaf in jax pytrees, to store compile-time
|
|
331
|
+
values """
|
|
332
|
+
|
|
333
|
+
def tree_flatten(self):
|
|
334
|
+
return (), self
|
|
335
|
+
|
|
336
|
+
@classmethod
|
|
337
|
+
def tree_unflatten(cls, aux_data, children):
|
|
338
|
+
return aux_data
|
|
339
|
+
|
|
340
|
+
def __repr__(self):
|
|
341
|
+
return f'{__class__.__name__}({super().__repr__()})'
|
bartz/mcmcloop.py
CHANGED
|
@@ -52,7 +52,7 @@ def run_mcmc(bart, n_burn, n_save, n_skip, callback, key):
|
|
|
52
52
|
n_save : int
|
|
53
53
|
The number of iterations to save.
|
|
54
54
|
n_skip : int
|
|
55
|
-
The number of iterations to skip between each saved iteration.
|
|
55
|
+
The number of iterations to skip between each saved iteration, plus 1.
|
|
56
56
|
callback : callable
|
|
57
57
|
An arbitrary function run at each iteration, called with the following
|
|
58
58
|
arguments, passed by keyword:
|
|
@@ -105,16 +105,19 @@ def run_mcmc(bart, n_burn, n_save, n_skip, callback, key):
|
|
|
105
105
|
output = {key: bart[key] for key in tracelist}
|
|
106
106
|
return (bart, i_total + 1, i_skip + 1, key), output
|
|
107
107
|
|
|
108
|
+
def empty_trace(bart, tracelist):
|
|
109
|
+
return {
|
|
110
|
+
key: jnp.empty((0,) + bart[key].shape, bart[key].dtype)
|
|
111
|
+
for key in tracelist
|
|
112
|
+
}
|
|
113
|
+
|
|
108
114
|
if n_burn > 0:
|
|
109
115
|
carry = bart, 0, 0, key
|
|
110
116
|
burnin_loop = functools.partial(inner_loop, tracelist=tracelist_burnin, burnin=True)
|
|
111
117
|
(bart, i_total, _, key), burnin_trace = lax.scan(burnin_loop, carry, None, n_burn)
|
|
112
118
|
else:
|
|
113
119
|
i_total = 0
|
|
114
|
-
burnin_trace =
|
|
115
|
-
key: jnp.empty((0,) + bart[key].shape, bart[key].dtype)
|
|
116
|
-
for key in tracelist_burnin
|
|
117
|
-
}
|
|
120
|
+
burnin_trace = empty_trace(bart, tracelist_burnin)
|
|
118
121
|
|
|
119
122
|
def outer_loop(carry, _):
|
|
120
123
|
bart, i_total, key = carry
|
|
@@ -124,8 +127,11 @@ def run_mcmc(bart, n_burn, n_save, n_skip, callback, key):
|
|
|
124
127
|
output = {key: bart[key] for key in tracelist_main}
|
|
125
128
|
return (bart, i_total, key), output
|
|
126
129
|
|
|
127
|
-
|
|
128
|
-
|
|
130
|
+
if n_save > 0:
|
|
131
|
+
carry = bart, i_total, key
|
|
132
|
+
(bart, _, _), main_trace = lax.scan(outer_loop, carry, None, n_save)
|
|
133
|
+
else:
|
|
134
|
+
main_trace = empty_trace(bart, tracelist_main)
|
|
129
135
|
|
|
130
136
|
return bart, burnin_trace, main_trace
|
|
131
137
|
|
|
@@ -133,7 +139,8 @@ def run_mcmc(bart, n_burn, n_save, n_skip, callback, key):
|
|
|
133
139
|
|
|
134
140
|
@functools.lru_cache
|
|
135
141
|
# cache to make the callback function object unique, such that the jit
|
|
136
|
-
# of run_mcmc recognizes it
|
|
142
|
+
# of run_mcmc recognizes it => with the callback state, I can make
|
|
143
|
+
# printevery a runtime quantity
|
|
137
144
|
def make_simple_print_callback(printevery):
|
|
138
145
|
"""
|
|
139
146
|
Create a logging callback function for MCMC iterations.
|
|
@@ -155,11 +162,12 @@ def make_simple_print_callback(printevery):
|
|
|
155
162
|
grow_acc = bart['grow_acc_count'] / bart['grow_prop_count']
|
|
156
163
|
prune_acc = bart['prune_acc_count'] / bart['prune_prop_count']
|
|
157
164
|
n_total = n_burn + n_save * n_skip
|
|
158
|
-
|
|
165
|
+
printcond = (i_total + 1) % printevery == 0
|
|
166
|
+
debug.callback(_simple_print_callback, burnin, i_total, n_total, grow_prop, grow_acc, prune_prop, prune_acc, printcond)
|
|
159
167
|
return callback
|
|
160
168
|
|
|
161
|
-
def
|
|
162
|
-
if
|
|
169
|
+
def _simple_print_callback(burnin, i_total, n_total, grow_prop, grow_acc, prune_prop, prune_acc, printcond):
|
|
170
|
+
if printcond:
|
|
163
171
|
burnin_flag = ' (burnin)' if burnin else ''
|
|
164
172
|
total_str = str(n_total)
|
|
165
173
|
ndigits = len(total_str)
|
bartz/mcmcstep.py
CHANGED
|
@@ -34,7 +34,6 @@ range of possible values.
|
|
|
34
34
|
"""
|
|
35
35
|
|
|
36
36
|
import functools
|
|
37
|
-
import math
|
|
38
37
|
|
|
39
38
|
import jax
|
|
40
39
|
from jax import random
|
|
@@ -55,6 +54,7 @@ def init(*,
|
|
|
55
54
|
small_float=jnp.float32,
|
|
56
55
|
large_float=jnp.float32,
|
|
57
56
|
min_points_per_leaf=None,
|
|
57
|
+
suffstat_batch_size='auto',
|
|
58
58
|
):
|
|
59
59
|
"""
|
|
60
60
|
Make a BART posterior sampling MCMC initial state.
|
|
@@ -82,6 +82,10 @@ def init(*,
|
|
|
82
82
|
The dtype for scalars, small arrays, and arrays which require accuracy.
|
|
83
83
|
min_points_per_leaf : int, optional
|
|
84
84
|
The minimum number of data points in a leaf node. 0 if not specified.
|
|
85
|
+
suffstat_batch_size : int, None, str, default 'auto'
|
|
86
|
+
The batch size for computing sufficient statistics. `None` for no
|
|
87
|
+
batching. If 'auto', pick a value based on the device of `y`, or the
|
|
88
|
+
default device.
|
|
85
89
|
|
|
86
90
|
Returns
|
|
87
91
|
-------
|
|
@@ -104,8 +108,9 @@ def init(*,
|
|
|
104
108
|
The number of grow/prune proposals made during one full MCMC cycle.
|
|
105
109
|
'grow_acc_count', 'prune_acc_count' : int
|
|
106
110
|
The number of grow/prune moves accepted during one full MCMC cycle.
|
|
107
|
-
'p_nonterminal' : large_float array (d
|
|
108
|
-
The probability of a nonterminal node at each depth
|
|
111
|
+
'p_nonterminal' : large_float array (d,)
|
|
112
|
+
The probability of a nonterminal node at each depth, padded with a
|
|
113
|
+
zero.
|
|
109
114
|
'sigma2_alpha' : large_float
|
|
110
115
|
The shape parameter of the inverse gamma prior on the noise variance.
|
|
111
116
|
'sigma2_beta' : large_float
|
|
@@ -121,18 +126,36 @@ def init(*,
|
|
|
121
126
|
'affluence_trees' : bool array (num_trees, 2 ** (d - 1)) or None
|
|
122
127
|
Whether a non-bottom leaf nodes contains twice `min_points_per_leaf`
|
|
123
128
|
datapoints. If `min_points_per_leaf` is not specified, this is None.
|
|
129
|
+
'opt' : LeafDict
|
|
130
|
+
A dictionary with config values:
|
|
131
|
+
|
|
132
|
+
'suffstat_batch_size' : int or None
|
|
133
|
+
The batch size for computing sufficient statistics.
|
|
134
|
+
'small_float' : dtype
|
|
135
|
+
The dtype for large arrays used in the algorithm.
|
|
136
|
+
'large_float' : dtype
|
|
137
|
+
The dtype for scalars, small arrays, and arrays which require
|
|
138
|
+
accuracy.
|
|
139
|
+
'require_min_points' : bool
|
|
140
|
+
Whether the `min_points_per_leaf` parameter is specified.
|
|
124
141
|
"""
|
|
125
142
|
|
|
126
143
|
p_nonterminal = jnp.asarray(p_nonterminal, large_float)
|
|
127
|
-
|
|
144
|
+
p_nonterminal = jnp.pad(p_nonterminal, (0, 1))
|
|
145
|
+
max_depth = p_nonterminal.size
|
|
128
146
|
|
|
129
147
|
@functools.partial(jax.vmap, in_axes=None, out_axes=0, axis_size=num_trees)
|
|
130
148
|
def make_forest(max_depth, dtype):
|
|
131
149
|
return grove.make_tree(max_depth, dtype)
|
|
132
150
|
|
|
151
|
+
small_float = jnp.dtype(small_float)
|
|
152
|
+
large_float = jnp.dtype(large_float)
|
|
153
|
+
y = jnp.asarray(y, small_float)
|
|
154
|
+
suffstat_batch_size = _choose_suffstat_batch_size(suffstat_batch_size, y)
|
|
155
|
+
|
|
133
156
|
bart = dict(
|
|
134
157
|
leaf_trees=make_forest(max_depth, small_float),
|
|
135
|
-
var_trees=make_forest(max_depth - 1,
|
|
158
|
+
var_trees=make_forest(max_depth - 1, jaxext.minimal_unsigned_dtype(X.shape[0] - 1)),
|
|
136
159
|
split_trees=make_forest(max_depth - 1, max_split.dtype),
|
|
137
160
|
resid=jnp.asarray(y, large_float),
|
|
138
161
|
sigma2=jnp.ones((), large_float),
|
|
@@ -143,9 +166,9 @@ def init(*,
|
|
|
143
166
|
p_nonterminal=p_nonterminal,
|
|
144
167
|
sigma2_alpha=jnp.asarray(sigma2_alpha, large_float),
|
|
145
168
|
sigma2_beta=jnp.asarray(sigma2_beta, large_float),
|
|
146
|
-
max_split=max_split,
|
|
147
|
-
y=
|
|
148
|
-
X=X,
|
|
169
|
+
max_split=jnp.asarray(max_split),
|
|
170
|
+
y=y,
|
|
171
|
+
X=jnp.asarray(X),
|
|
149
172
|
min_points_per_leaf=(
|
|
150
173
|
None if min_points_per_leaf is None else
|
|
151
174
|
jnp.asarray(min_points_per_leaf)
|
|
@@ -154,10 +177,39 @@ def init(*,
|
|
|
154
177
|
None if min_points_per_leaf is None else
|
|
155
178
|
make_forest(max_depth - 1, bool).at[:, 1].set(y.size >= 2 * min_points_per_leaf)
|
|
156
179
|
),
|
|
180
|
+
opt=jaxext.LeafDict(
|
|
181
|
+
suffstat_batch_size=suffstat_batch_size,
|
|
182
|
+
small_float=small_float,
|
|
183
|
+
large_float=large_float,
|
|
184
|
+
require_min_points=min_points_per_leaf is not None,
|
|
185
|
+
),
|
|
157
186
|
)
|
|
158
187
|
|
|
159
188
|
return bart
|
|
160
189
|
|
|
190
|
+
def _choose_suffstat_batch_size(size, y):
|
|
191
|
+
if size == 'auto':
|
|
192
|
+
try:
|
|
193
|
+
device = y.devices().pop()
|
|
194
|
+
except jax.errors.ConcretizationTypeError:
|
|
195
|
+
device = jax.devices()[0]
|
|
196
|
+
platform = device.platform
|
|
197
|
+
|
|
198
|
+
if platform == 'cpu':
|
|
199
|
+
return None
|
|
200
|
+
# maybe I should batch residuals (not counts) for numerical
|
|
201
|
+
# accuracy, even if it's slower
|
|
202
|
+
elif platform == 'gpu':
|
|
203
|
+
return 128 # 128 is good on A100, and V100 at high n
|
|
204
|
+
# 512 is good on T4, and V100 at low n
|
|
205
|
+
else:
|
|
206
|
+
raise KeyError(f'Unknown platform: {platform}')
|
|
207
|
+
|
|
208
|
+
elif size is not None:
|
|
209
|
+
return int(size)
|
|
210
|
+
|
|
211
|
+
return size
|
|
212
|
+
|
|
161
213
|
def step(bart, key):
|
|
162
214
|
"""
|
|
163
215
|
Perform one full MCMC step on a BART state.
|
|
@@ -196,11 +248,14 @@ def sample_trees(bart, key):
|
|
|
196
248
|
|
|
197
249
|
Notes
|
|
198
250
|
-----
|
|
199
|
-
This function zeroes the proposal counters.
|
|
251
|
+
This function zeroes the proposal counters before using them.
|
|
200
252
|
"""
|
|
253
|
+
bart = bart.copy()
|
|
201
254
|
key, subkey = random.split(key)
|
|
202
255
|
grow_moves, prune_moves = sample_moves(bart, subkey)
|
|
203
|
-
|
|
256
|
+
bart['var_trees'] = grow_moves['var_tree']
|
|
257
|
+
grow_leaf_indices = grove.traverse_forest(bart['X'], grow_moves['var_tree'], grow_moves['split_tree'])
|
|
258
|
+
return accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, grow_leaf_indices, key)
|
|
204
259
|
|
|
205
260
|
def sample_moves(bart, key):
|
|
206
261
|
"""
|
|
@@ -216,20 +271,7 @@ def sample_moves(bart, key):
|
|
|
216
271
|
Returns
|
|
217
272
|
-------
|
|
218
273
|
grow_moves, prune_moves : dict
|
|
219
|
-
The proposals for grow and prune moves
|
|
220
|
-
|
|
221
|
-
'allowed' : bool array (num_trees,)
|
|
222
|
-
Whether the move is possible.
|
|
223
|
-
'node' : int array (num_trees,)
|
|
224
|
-
The index of the leaf to grow or node to prune.
|
|
225
|
-
'var_tree' : int array (num_trees, 2 ** (d - 1),)
|
|
226
|
-
The new decision axes of the tree.
|
|
227
|
-
'split_tree' : int array (num_trees, 2 ** (d - 1),)
|
|
228
|
-
The new decision boundaries of the tree.
|
|
229
|
-
'partial_ratio' : float array (num_trees,)
|
|
230
|
-
A factor of the Metropolis-Hastings ratio of the move. It lacks
|
|
231
|
-
the likelihood ratio, and the probability of proposing the prune
|
|
232
|
-
move. For the prune move, the ratio is inverted.
|
|
274
|
+
The proposals for grow and prune moves. See `grow_move` and `prune_move`.
|
|
233
275
|
"""
|
|
234
276
|
key = random.split(key, bart['var_trees'].shape[0])
|
|
235
277
|
return sample_moves_vmap_trees(bart['var_trees'], bart['split_trees'], bart['affluence_trees'], bart['max_split'], bart['p_nonterminal'], key)
|
|
@@ -260,7 +302,7 @@ def grow_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, ke
|
|
|
260
302
|
Whether a leaf has enough points to be grown.
|
|
261
303
|
max_split : array (p,)
|
|
262
304
|
The maximum split index for each variable.
|
|
263
|
-
p_nonterminal : array (d
|
|
305
|
+
p_nonterminal : array (d,)
|
|
264
306
|
The probability of a nonterminal node at each depth.
|
|
265
307
|
key : jax.dtypes.prng_key array
|
|
266
308
|
A jax random key.
|
|
@@ -292,16 +334,16 @@ def grow_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, ke
|
|
|
292
334
|
var_tree = var_tree.at[leaf_to_grow].set(var.astype(var_tree.dtype))
|
|
293
335
|
|
|
294
336
|
split = choose_split(var_tree, split_tree, max_split, leaf_to_grow, key2)
|
|
295
|
-
|
|
337
|
+
split_tree = split_tree.at[leaf_to_grow].set(split.astype(split_tree.dtype))
|
|
296
338
|
|
|
297
|
-
ratio = compute_partial_ratio(num_growable, num_prunable, p_nonterminal, leaf_to_grow, split_tree
|
|
339
|
+
ratio = compute_partial_ratio(num_growable, num_prunable, p_nonterminal, leaf_to_grow, split_tree)
|
|
298
340
|
|
|
299
341
|
return dict(
|
|
300
342
|
allowed=allowed,
|
|
301
343
|
node=leaf_to_grow,
|
|
302
|
-
var_tree=var_tree,
|
|
303
|
-
split_tree=new_split_tree,
|
|
304
344
|
partial_ratio=ratio,
|
|
345
|
+
var_tree=var_tree,
|
|
346
|
+
split_tree=split_tree,
|
|
305
347
|
)
|
|
306
348
|
|
|
307
349
|
def choose_leaf(split_tree, affluence_tree, key):
|
|
@@ -464,7 +506,7 @@ def ancestor_variables(var_tree, max_split, node_index):
|
|
|
464
506
|
the parent. Unused spots are filled with `p`.
|
|
465
507
|
"""
|
|
466
508
|
max_num_ancestors = grove.tree_depth(var_tree) - 1
|
|
467
|
-
ancestor_vars = jnp.zeros(max_num_ancestors,
|
|
509
|
+
ancestor_vars = jnp.zeros(max_num_ancestors, jaxext.minimal_unsigned_dtype(max_split.size))
|
|
468
510
|
carry = ancestor_vars.size - 1, node_index, ancestor_vars
|
|
469
511
|
def loop(carry, _):
|
|
470
512
|
i, index, ancestor_vars = carry
|
|
@@ -569,7 +611,7 @@ def choose_split(var_tree, split_tree, max_split, leaf_index, key):
|
|
|
569
611
|
l, r = split_range(var_tree, split_tree, max_split, leaf_index, var)
|
|
570
612
|
return random.randint(key, (), l, r)
|
|
571
613
|
|
|
572
|
-
def compute_partial_ratio(num_growable, num_prunable, p_nonterminal, leaf_to_grow,
|
|
614
|
+
def compute_partial_ratio(num_growable, num_prunable, p_nonterminal, leaf_to_grow, new_split_tree):
|
|
573
615
|
"""
|
|
574
616
|
Compute the product of the transition and prior ratios of a grow move.
|
|
575
617
|
|
|
@@ -580,12 +622,10 @@ def compute_partial_ratio(num_growable, num_prunable, p_nonterminal, leaf_to_gro
|
|
|
580
622
|
num_prunable : int
|
|
581
623
|
The number of leaf parents that could be pruned, after converting the
|
|
582
624
|
leaf to be grown to a non-terminal node.
|
|
583
|
-
p_nonterminal : array (d
|
|
625
|
+
p_nonterminal : array (d,)
|
|
584
626
|
The probability of a nonterminal node at each depth.
|
|
585
627
|
leaf_to_grow : int
|
|
586
628
|
The index of the leaf to grow.
|
|
587
|
-
initial_split_tree : array (2 ** (d - 1),)
|
|
588
|
-
The splitting points of the tree, before the leaf is grown.
|
|
589
629
|
new_split_tree : array (2 ** (d - 1),)
|
|
590
630
|
The splitting points of the tree, after the leaf is grown.
|
|
591
631
|
|
|
@@ -600,14 +640,18 @@ def compute_partial_ratio(num_growable, num_prunable, p_nonterminal, leaf_to_gro
|
|
|
600
640
|
# the two ratios also contain factors num_available_split *
|
|
601
641
|
# num_available_var, but they cancel out
|
|
602
642
|
|
|
603
|
-
|
|
604
|
-
|
|
643
|
+
prune_allowed = leaf_to_grow != 1
|
|
644
|
+
# prune allowed <---> the initial tree is not a root
|
|
645
|
+
# leaf to grow is root --> the tree can only be a root
|
|
646
|
+
# tree is a root --> the only leaf I can grow is root
|
|
647
|
+
|
|
648
|
+
p_grow = jnp.where(prune_allowed, 0.5, 1)
|
|
605
649
|
|
|
606
650
|
trans_ratio = num_growable / (p_grow * num_prunable)
|
|
607
651
|
|
|
608
|
-
depth = grove.tree_depths(
|
|
652
|
+
depth = grove.tree_depths(new_split_tree.size)[leaf_to_grow]
|
|
609
653
|
p_parent = p_nonterminal[depth]
|
|
610
|
-
cp_children = 1 - p_nonterminal
|
|
654
|
+
cp_children = 1 - p_nonterminal[depth + 1]
|
|
611
655
|
tree_ratio = cp_children * cp_children * p_parent / (1 - p_parent)
|
|
612
656
|
|
|
613
657
|
return trans_ratio * tree_ratio
|
|
@@ -626,7 +670,7 @@ def prune_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, k
|
|
|
626
670
|
Whether a leaf has enough points to be grown.
|
|
627
671
|
max_split : array (p,)
|
|
628
672
|
The maximum split index for each variable.
|
|
629
|
-
p_nonterminal : array (d
|
|
673
|
+
p_nonterminal : array (d,)
|
|
630
674
|
The probability of a nonterminal node at each depth.
|
|
631
675
|
key : jax.dtypes.prng_key array
|
|
632
676
|
A jax random key.
|
|
@@ -639,28 +683,20 @@ def prune_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, k
|
|
|
639
683
|
'allowed' : bool
|
|
640
684
|
Whether the move is possible.
|
|
641
685
|
'node' : int
|
|
642
|
-
The index of the
|
|
643
|
-
'var_tree' : array (2 ** (d - 1),)
|
|
644
|
-
The new decision axes of the tree.
|
|
645
|
-
'split_tree' : array (2 ** (d - 1),)
|
|
646
|
-
The new decision boundaries of the tree.
|
|
686
|
+
The index of the node to prune.
|
|
647
687
|
'partial_ratio' : float
|
|
648
688
|
A factor of the Metropolis-Hastings ratio of the move. It lacks
|
|
649
689
|
the likelihood ratio and the probability of proposing the prune
|
|
650
690
|
move. This ratio is inverted.
|
|
651
691
|
"""
|
|
652
692
|
node_to_prune, num_prunable, num_growable = choose_leaf_parent(split_tree, affluence_tree, key)
|
|
653
|
-
allowed =
|
|
654
|
-
|
|
655
|
-
new_split_tree = split_tree.at[node_to_prune].set(0)
|
|
693
|
+
allowed = split_tree[1].astype(bool) # allowed iff the tree is not a root
|
|
656
694
|
|
|
657
|
-
ratio = compute_partial_ratio(num_growable, num_prunable, p_nonterminal, node_to_prune,
|
|
695
|
+
ratio = compute_partial_ratio(num_growable, num_prunable, p_nonterminal, node_to_prune, split_tree)
|
|
658
696
|
|
|
659
697
|
return dict(
|
|
660
698
|
allowed=allowed,
|
|
661
699
|
node=node_to_prune,
|
|
662
|
-
var_tree=var_tree,
|
|
663
|
-
split_tree=new_split_tree,
|
|
664
700
|
partial_ratio=ratio, # it is inverted in accept_move_and_sample_leaves
|
|
665
701
|
)
|
|
666
702
|
|
|
@@ -702,29 +738,37 @@ def choose_leaf_parent(split_tree, affluence_tree, key):
|
|
|
702
738
|
|
|
703
739
|
return node_to_prune, num_prunable, num_growable
|
|
704
740
|
|
|
705
|
-
def
|
|
741
|
+
def accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, grow_leaf_indices, key):
|
|
706
742
|
"""
|
|
707
|
-
|
|
743
|
+
Accept or reject the proposed moves and sample the new leaf values.
|
|
708
744
|
|
|
709
745
|
Parameters
|
|
710
746
|
----------
|
|
711
|
-
|
|
712
|
-
|
|
747
|
+
bart : dict
|
|
748
|
+
A BART mcmc state.
|
|
749
|
+
grow_moves : dict
|
|
750
|
+
The proposals for grow moves, batched over the first axis. See
|
|
751
|
+
`grow_move`.
|
|
752
|
+
prune_moves : dict
|
|
753
|
+
The proposals for prune moves, batched over the first axis. See
|
|
754
|
+
`prune_move`.
|
|
755
|
+
grow_leaf_indices : int array (num_trees, n)
|
|
756
|
+
The leaf indices of the trees proposed by the grow move.
|
|
757
|
+
key : jax.dtypes.prng_key array
|
|
758
|
+
A jax random key.
|
|
713
759
|
|
|
714
760
|
Returns
|
|
715
761
|
-------
|
|
716
|
-
|
|
717
|
-
|
|
762
|
+
bart : dict
|
|
763
|
+
The new BART mcmc state.
|
|
718
764
|
"""
|
|
719
|
-
return split_tree.at[1].get(mode='fill', fill_value=0).astype(bool)
|
|
720
|
-
|
|
721
|
-
def accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, key):
|
|
722
765
|
bart = bart.copy()
|
|
723
766
|
def loop(carry, item):
|
|
724
767
|
resid = carry.pop('resid')
|
|
725
768
|
resid, carry, trees = accept_move_and_sample_leaves(
|
|
726
769
|
bart['X'],
|
|
727
770
|
len(bart['leaf_trees']),
|
|
771
|
+
bart['opt']['suffstat_batch_size'],
|
|
728
772
|
resid,
|
|
729
773
|
bart['sigma2'],
|
|
730
774
|
bart['min_points_per_leaf'],
|
|
@@ -740,11 +784,11 @@ def accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, key):
|
|
|
740
784
|
carry['resid'] = bart['resid']
|
|
741
785
|
items = (
|
|
742
786
|
bart['leaf_trees'],
|
|
743
|
-
bart['var_trees'],
|
|
744
787
|
bart['split_trees'],
|
|
745
788
|
bart['affluence_trees'],
|
|
746
789
|
grow_moves,
|
|
747
790
|
prune_moves,
|
|
791
|
+
grow_leaf_indices,
|
|
748
792
|
random.split(key, len(bart['leaf_trees'])),
|
|
749
793
|
)
|
|
750
794
|
carry, trees = lax.scan(loop, carry, items)
|
|
@@ -752,11 +796,50 @@ def accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, key):
|
|
|
752
796
|
bart.update(trees)
|
|
753
797
|
return bart
|
|
754
798
|
|
|
755
|
-
def accept_move_and_sample_leaves(X, ntree, resid, sigma2, min_points_per_leaf, counts, leaf_tree,
|
|
756
|
-
|
|
757
|
-
|
|
758
|
-
|
|
759
|
-
|
|
799
|
+
def accept_move_and_sample_leaves(X, ntree, suffstat_batch_size, resid, sigma2, min_points_per_leaf, counts, leaf_tree, split_tree, affluence_tree, grow_move, prune_move, grow_leaf_indices, key):
|
|
800
|
+
"""
|
|
801
|
+
Accept or reject a proposed move and sample the new leaf values.
|
|
802
|
+
|
|
803
|
+
Parameters
|
|
804
|
+
----------
|
|
805
|
+
X : int array (p, n)
|
|
806
|
+
The predictors.
|
|
807
|
+
ntree : int
|
|
808
|
+
The number of trees in the forest.
|
|
809
|
+
suffstat_batch_size : int, None
|
|
810
|
+
The batch size for computing sufficient statistics.
|
|
811
|
+
resid : float array (n,)
|
|
812
|
+
The residuals (data minus forest value).
|
|
813
|
+
sigma2 : float
|
|
814
|
+
The noise variance.
|
|
815
|
+
min_points_per_leaf : int or None
|
|
816
|
+
The minimum number of data points in a leaf node.
|
|
817
|
+
counts : dict
|
|
818
|
+
The acceptance counts from the mcmc state dict.
|
|
819
|
+
leaf_tree : float array (2 ** d,)
|
|
820
|
+
The leaf values of the tree.
|
|
821
|
+
split_tree : int array (2 ** (d - 1),)
|
|
822
|
+
The decision boundaries of the tree.
|
|
823
|
+
affluence_tree : bool array (2 ** (d - 1),) or None
|
|
824
|
+
Whether a leaf has enough points to be grown.
|
|
825
|
+
grow_move : dict
|
|
826
|
+
The proposal for the grow move. See `grow_move`.
|
|
827
|
+
prune_move : dict
|
|
828
|
+
The proposal for the prune move. See `prune_move`.
|
|
829
|
+
grow_leaf_indices : int array (n,)
|
|
830
|
+
The leaf indices of the tree proposed by the grow move.
|
|
831
|
+
key : jax.dtypes.prng_key array
|
|
832
|
+
A jax random key.
|
|
833
|
+
|
|
834
|
+
Returns
|
|
835
|
+
-------
|
|
836
|
+
resid : float array (n,)
|
|
837
|
+
The updated residuals (data minus forest value).
|
|
838
|
+
counts : dict
|
|
839
|
+
The updated acceptance counts.
|
|
840
|
+
trees : dict
|
|
841
|
+
The updated tree arrays.
|
|
842
|
+
"""
|
|
760
843
|
|
|
761
844
|
# compute leaf indices in starting tree
|
|
762
845
|
grow_node = grow_move['node']
|
|
@@ -782,10 +865,7 @@ def accept_move_and_sample_leaves(X, ntree, resid, sigma2, min_points_per_leaf,
|
|
|
782
865
|
resid += leaf_tree[leaf_indices]
|
|
783
866
|
|
|
784
867
|
# aggregate residuals and count units per leaf
|
|
785
|
-
grow_resid_tree =
|
|
786
|
-
grow_resid_tree = grow_resid_tree.at[grow_leaf_indices].add(resid)
|
|
787
|
-
grow_count_tree = jnp.zeros_like(leaf_tree, grove.minimal_unsigned_dtype(resid.size))
|
|
788
|
-
grow_count_tree = grow_count_tree.at[grow_leaf_indices].add(1)
|
|
868
|
+
grow_resid_tree, grow_count_tree = sufficient_stat(resid, grow_leaf_indices, leaf_tree.size, suffstat_batch_size)
|
|
789
869
|
|
|
790
870
|
# compute aggregations in starting tree
|
|
791
871
|
# I do not zero the children because garbage there does not matter
|
|
@@ -833,10 +913,10 @@ def accept_move_and_sample_leaves(X, ntree, resid, sigma2, min_points_per_leaf,
|
|
|
833
913
|
|
|
834
914
|
# pick trees for chosen move
|
|
835
915
|
trees = {}
|
|
836
|
-
var_tree = jnp.where(do_grow, grow_move['var_tree'], var_tree)
|
|
837
916
|
split_tree = jnp.where(do_grow, grow_move['split_tree'], split_tree)
|
|
838
|
-
|
|
839
|
-
split_tree =
|
|
917
|
+
# the prune var tree is equal to the initial one, because I leave garbage values behind
|
|
918
|
+
split_tree = split_tree.at[prune_node].set(
|
|
919
|
+
jnp.where(do_prune, 0, split_tree[prune_node]))
|
|
840
920
|
if min_points_per_leaf is not None:
|
|
841
921
|
affluence_tree = jnp.where(do_grow, grow_affluence_tree, affluence_tree)
|
|
842
922
|
affluence_tree = jnp.where(do_prune, prune_affluence_tree, affluence_tree)
|
|
@@ -869,13 +949,60 @@ def accept_move_and_sample_leaves(X, ntree, resid, sigma2, min_points_per_leaf,
|
|
|
869
949
|
# pack trees
|
|
870
950
|
trees = {
|
|
871
951
|
'leaf_trees': leaf_tree,
|
|
872
|
-
'var_trees': var_tree,
|
|
873
952
|
'split_trees': split_tree,
|
|
874
953
|
'affluence_trees': affluence_tree,
|
|
875
954
|
}
|
|
876
955
|
|
|
877
956
|
return resid, counts, trees
|
|
878
957
|
|
|
958
|
+
def sufficient_stat(resid, leaf_indices, tree_size, batch_size):
|
|
959
|
+
"""
|
|
960
|
+
Compute the sufficient statistics for the likelihood ratio of a tree move.
|
|
961
|
+
|
|
962
|
+
Parameters
|
|
963
|
+
----------
|
|
964
|
+
resid : float array (n,)
|
|
965
|
+
The residuals (data minus forest value).
|
|
966
|
+
leaf_indices : int array (n,)
|
|
967
|
+
The leaf indices of the tree (in which leaf each data point falls into).
|
|
968
|
+
tree_size : int
|
|
969
|
+
The size of the tree array (2 ** d).
|
|
970
|
+
batch_size : int, None
|
|
971
|
+
The batch size for the aggregation. Batching increases numerical
|
|
972
|
+
accuracy and parallelism.
|
|
973
|
+
|
|
974
|
+
Returns
|
|
975
|
+
-------
|
|
976
|
+
resid_tree : float array (2 ** d,)
|
|
977
|
+
The sum of the residuals at data points in each leaf.
|
|
978
|
+
count_tree : int array (2 ** d,)
|
|
979
|
+
The number of data points in each leaf.
|
|
980
|
+
"""
|
|
981
|
+
if batch_size is None:
|
|
982
|
+
aggr_func = _aggregate_scatter
|
|
983
|
+
else:
|
|
984
|
+
aggr_func = functools.partial(_aggregate_batched, batch_size=batch_size)
|
|
985
|
+
resid_tree = aggr_func(resid, leaf_indices, tree_size, jnp.float32)
|
|
986
|
+
count_tree = aggr_func(1, leaf_indices, tree_size, jnp.uint32)
|
|
987
|
+
return resid_tree, count_tree
|
|
988
|
+
|
|
989
|
+
def _aggregate_scatter(values, indices, size, dtype):
|
|
990
|
+
return (jnp
|
|
991
|
+
.zeros(size, dtype)
|
|
992
|
+
.at[indices]
|
|
993
|
+
.add(values)
|
|
994
|
+
)
|
|
995
|
+
|
|
996
|
+
def _aggregate_batched(values, indices, size, dtype, batch_size):
|
|
997
|
+
nbatches = indices.size // batch_size + bool(indices.size % batch_size)
|
|
998
|
+
batch_indices = jnp.arange(indices.size) // batch_size
|
|
999
|
+
return (jnp
|
|
1000
|
+
.zeros((nbatches, size), dtype)
|
|
1001
|
+
.at[batch_indices, indices]
|
|
1002
|
+
.add(values)
|
|
1003
|
+
.sum(axis=0)
|
|
1004
|
+
)
|
|
1005
|
+
|
|
879
1006
|
def compute_p_prune_back(new_split_tree, new_affluence_tree):
|
|
880
1007
|
"""
|
|
881
1008
|
Compute the probability of proposing a prune move after doing a grow move.
|
bartz/prepcovars.py
CHANGED
|
@@ -27,8 +27,10 @@ import functools
|
|
|
27
27
|
import jax
|
|
28
28
|
from jax import numpy as jnp
|
|
29
29
|
|
|
30
|
+
from . import jaxext
|
|
30
31
|
from . import grove
|
|
31
32
|
|
|
33
|
+
@functools.partial(jax.jit, static_argnums=(1,))
|
|
32
34
|
def quantilized_splits_from_matrix(X, max_bins):
|
|
33
35
|
"""
|
|
34
36
|
Determine bins that make the distribution of each predictor uniform.
|
|
@@ -52,48 +54,41 @@ def quantilized_splits_from_matrix(X, max_bins):
|
|
|
52
54
|
The number of actually used values in each row of `splits`.
|
|
53
55
|
"""
|
|
54
56
|
out_length = min(max_bins, X.shape[1]) - 1
|
|
55
|
-
return
|
|
57
|
+
# return _quantilized_splits_from_matrix(X, out_length)
|
|
58
|
+
@functools.partial(jaxext.autobatch, max_io_nbytes=500_000_000)
|
|
59
|
+
def func(X):
|
|
60
|
+
return _quantilized_splits_from_matrix(X, out_length)
|
|
61
|
+
return func(X)
|
|
56
62
|
|
|
57
63
|
@functools.partial(jax.vmap, in_axes=(0, None))
|
|
58
|
-
def
|
|
59
|
-
huge = huge_value(x)
|
|
60
|
-
u =
|
|
61
|
-
actual_length
|
|
62
|
-
|
|
64
|
+
def _quantilized_splits_from_matrix(x, out_length):
|
|
65
|
+
huge = jaxext.huge_value(x)
|
|
66
|
+
u, actual_length = jaxext.unique(x, size=x.size, fill_value=huge)
|
|
67
|
+
actual_length -= 1
|
|
68
|
+
if jnp.issubdtype(x.dtype, jnp.integer):
|
|
69
|
+
midpoints = u[:-1] + jaxext.ensure_unsigned(u[1:] - u[:-1]) // 2
|
|
70
|
+
indices = jnp.arange(midpoints.size, dtype=jaxext.minimal_unsigned_dtype(midpoints.size - 1))
|
|
71
|
+
midpoints = jnp.where(indices < actual_length, midpoints, huge)
|
|
72
|
+
else:
|
|
73
|
+
midpoints = (u[1:] + u[:-1]) / 2
|
|
63
74
|
indices = jnp.linspace(-1, actual_length, out_length + 2)[1:-1]
|
|
64
|
-
indices = jnp.around(indices).astype(
|
|
75
|
+
indices = jnp.around(indices).astype(jaxext.minimal_unsigned_dtype(midpoints.size - 1))
|
|
65
76
|
# indices calculation with float rather than int to avoid potential
|
|
66
77
|
# overflow with int32, and to round to nearest instead of rounding down
|
|
67
78
|
decimated_midpoints = midpoints[indices]
|
|
68
79
|
truncated_midpoints = midpoints[:out_length]
|
|
69
80
|
splits = jnp.where(actual_length > out_length, decimated_midpoints, truncated_midpoints)
|
|
70
81
|
max_split = jnp.minimum(actual_length, out_length)
|
|
71
|
-
max_split = max_split.astype(
|
|
82
|
+
max_split = max_split.astype(jaxext.minimal_unsigned_dtype(out_length))
|
|
72
83
|
return splits, max_split
|
|
73
84
|
|
|
74
|
-
|
|
75
|
-
"""
|
|
76
|
-
Return the maximum value that can be stored in `x`.
|
|
77
|
-
|
|
78
|
-
Parameters
|
|
79
|
-
----------
|
|
80
|
-
x : array
|
|
81
|
-
A numerical numpy or jax array.
|
|
82
|
-
|
|
83
|
-
Returns
|
|
84
|
-
-------
|
|
85
|
-
maxval : scalar
|
|
86
|
-
The maximum value allowed by `x`'s type (+inf for floats).
|
|
87
|
-
"""
|
|
88
|
-
if jnp.issubdtype(x.dtype, jnp.integer):
|
|
89
|
-
return jnp.iinfo(x.dtype).max
|
|
90
|
-
else:
|
|
91
|
-
return jnp.inf
|
|
92
|
-
|
|
85
|
+
@jax.jit
|
|
93
86
|
def bin_predictors(X, splits):
|
|
94
87
|
"""
|
|
95
88
|
Bin the predictors according to the given splits.
|
|
96
89
|
|
|
90
|
+
A value ``x`` is mapped to bin ``i`` iff ``splits[i - 1] < x <= splits[i]``.
|
|
91
|
+
|
|
97
92
|
Parameters
|
|
98
93
|
----------
|
|
99
94
|
X : array (p, n)
|
|
@@ -110,9 +105,9 @@ def bin_predictors(X, splits):
|
|
|
110
105
|
A matrix with `p` predictors and `n` observations, where each predictor
|
|
111
106
|
has been replaced by the index of the bin it falls into.
|
|
112
107
|
"""
|
|
113
|
-
return
|
|
108
|
+
return _bin_predictors(X, splits)
|
|
114
109
|
|
|
115
110
|
@jax.vmap
|
|
116
|
-
def
|
|
117
|
-
dtype =
|
|
111
|
+
def _bin_predictors(x, splits):
|
|
112
|
+
dtype = jaxext.minimal_unsigned_dtype(splits.size)
|
|
118
113
|
return jnp.searchsorted(splits, x).astype(dtype)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: bartz
|
|
3
|
-
Version: 0.1
|
|
3
|
+
Version: 0.2.1
|
|
4
4
|
Summary: A JAX implementation of BART
|
|
5
5
|
Home-page: https://github.com/Gattocrucco/bartz
|
|
6
6
|
License: MIT
|
|
@@ -20,7 +20,13 @@ Project-URL: Bug Tracker, https://github.com/Gattocrucco/bartz/issues
|
|
|
20
20
|
Project-URL: Repository, https://github.com/Gattocrucco/bartz
|
|
21
21
|
Description-Content-Type: text/markdown
|
|
22
22
|
|
|
23
|
+
[](https://pypi.org/project/bartz/)
|
|
24
|
+
|
|
23
25
|
# BART vectoriZed
|
|
24
26
|
|
|
25
27
|
A branchless vectorized implementation of Bayesian Additive Regression Trees (BART) in JAX.
|
|
26
28
|
|
|
29
|
+
BART is a nonparametric Bayesian regression technique. Given predictors $X$ and responses $y$, BART finds a function to predict $y$ given $X$. The result of the inference is a sample of possible functions, representing the uncertainty over the determination of the function.
|
|
30
|
+
|
|
31
|
+
This Python module provides an implementation of BART that runs on GPU, to process large datasets faster. It is also a good on CPU. Most other implementations of BART are for R, and run on CPU only.
|
|
32
|
+
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
bartz/BART.py,sha256=pRG7mALenknX2JHqY-VyhO9-evDgEC6hWBp4jpecBdM,15801
|
|
2
|
+
bartz/__init__.py,sha256=E96vsP0bZ8brejpZmEmRoXuMsUdinO_B_SKUUl1rLsg,1448
|
|
3
|
+
bartz/_version.py,sha256=PmcQ2PI2oP8irnLtJLJby2YfW6sBvLAmL-VpABzTqwc,22
|
|
4
|
+
bartz/debug.py,sha256=9ZH-JfwZVu5OPhHBEyXQHAU5H9KIu1vxLK7yNv4m4Ew,5314
|
|
5
|
+
bartz/grove.py,sha256=Wj_7jHl9w3uwuVdH4hoeXowimGpdRE2lGIzr4aDkzsI,8291
|
|
6
|
+
bartz/jaxext.py,sha256=VYA41D5F7DYcAAVtkcZtEN927HxQGOOQM-uGsgr2CPc,10996
|
|
7
|
+
bartz/mcmcloop.py,sha256=lheLrjVxmlyQzc_92zeNsFhdkrhEWQEjoAWFbVzknnw,7701
|
|
8
|
+
bartz/mcmcstep.py,sha256=6fzNMumXjMe6Fj6zoHLTf1D42JuAiQyGHfr6l1Bwrnk,39450
|
|
9
|
+
bartz/prepcovars.py,sha256=iiQ0WjSj4--l5DgPW626Qg2SSB6ljnaaUsBz_A8kFrI,4634
|
|
10
|
+
bartz-0.2.1.dist-info/LICENSE,sha256=heuIJZQK9IexJYC-fYHoLUrgj8HG8yS3G072EvKh-94,1073
|
|
11
|
+
bartz-0.2.1.dist-info/METADATA,sha256=eGxicC1iR-Bpjk1uKn50g6FxdFfq9S70nl7m5GmXO14,1490
|
|
12
|
+
bartz-0.2.1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
|
13
|
+
bartz-0.2.1.dist-info/RECORD,,
|
bartz-0.1.0.dist-info/RECORD
DELETED
|
@@ -1,13 +0,0 @@
|
|
|
1
|
-
bartz/__init__.py,sha256=40tX5XHoTiGnZcoeogVpyNOM_5rbHt-Y6zTI0NS7OA4,1345
|
|
2
|
-
bartz/_version.py,sha256=IMjkMO3twhQzluVTo8Z6rE7Eg-9U79_LGKMcsWLKBkY,22
|
|
3
|
-
bartz/debug.py,sha256=_HOjDieipAgliP6B6C0UMgz-mVgmeZ3zmtzVe-iMGtY,5289
|
|
4
|
-
bartz/grove.py,sha256=LHhnvNKLb-jxUf4YjP927Hf9txkXynhMZ2ejtMRWZl4,8353
|
|
5
|
-
bartz/interface.py,sha256=INyNuHzFySwXAsXofVZDpTsMv78AR_3VCvAHbZFh92c,15724
|
|
6
|
-
bartz/jaxext.py,sha256=FK5j1zfW1yR4-yPKcD7ZvKSkVQ5--jHjQpVCl4n4gXY,2844
|
|
7
|
-
bartz/mcmcloop.py,sha256=xTxC1AkNX8jCrMArblvlMjnjMh80q1M3a6ZGrDdfsFI,7423
|
|
8
|
-
bartz/mcmcstep.py,sha256=6zkpTqgIrapeVy9mhy6BlsIO0s26HwBRDfw_6dVMmZA,35207
|
|
9
|
-
bartz/prepcovars.py,sha256=3ddDOtNNop3Ba2Kgy_dZ6apFydtwaEXH3uXSmmKf9Fs,4421
|
|
10
|
-
bartz-0.1.0.dist-info/LICENSE,sha256=heuIJZQK9IexJYC-fYHoLUrgj8HG8yS3G072EvKh-94,1073
|
|
11
|
-
bartz-0.1.0.dist-info/METADATA,sha256=8YYlbCf7frDtT2of6tNlnBbuGqyO8YyYlED8OXSiBpA,933
|
|
12
|
-
bartz-0.1.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
|
13
|
-
bartz-0.1.0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|