bartz 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bartz/{interface.py → BART.py} +4 -3
- bartz/__init__.py +6 -1
- bartz/_version.py +1 -1
- bartz/debug.py +5 -4
- bartz/grove.py +36 -36
- bartz/jaxext.py +261 -5
- bartz/mcmcloop.py +19 -11
- bartz/mcmcstep.py +192 -73
- bartz/prepcovars.py +25 -30
- {bartz-0.1.0.dist-info → bartz-0.2.0.dist-info}/METADATA +7 -1
- bartz-0.2.0.dist-info/RECORD +13 -0
- bartz-0.1.0.dist-info/RECORD +0 -13
- {bartz-0.1.0.dist-info → bartz-0.2.0.dist-info}/LICENSE +0 -0
- {bartz-0.1.0.dist-info → bartz-0.2.0.dist-info}/WHEEL +0 -0
bartz/{interface.py → BART.py}
RENAMED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# bartz/src/bartz/
|
|
1
|
+
# bartz/src/bartz/BART.py
|
|
2
2
|
#
|
|
3
3
|
# Copyright (c) 2024, Giacomo Petrillo
|
|
4
4
|
#
|
|
@@ -33,7 +33,7 @@ from . import mcmcstep
|
|
|
33
33
|
from . import mcmcloop
|
|
34
34
|
from . import prepcovars
|
|
35
35
|
|
|
36
|
-
class
|
|
36
|
+
class gbart:
|
|
37
37
|
"""
|
|
38
38
|
Nonparametric regression with Bayesian Additive Regression Trees (BART).
|
|
39
39
|
|
|
@@ -133,7 +133,7 @@ class BART:
|
|
|
133
133
|
|
|
134
134
|
Notes
|
|
135
135
|
-----
|
|
136
|
-
This interface imitates the function `
|
|
136
|
+
This interface imitates the function `gbart` from the R package `BART
|
|
137
137
|
<https://cran.r-project.org/package=BART>`_, but with these differences:
|
|
138
138
|
|
|
139
139
|
- If `x_train` and `x_test` are matrices, they have one predictor per row
|
|
@@ -142,6 +142,7 @@ class BART:
|
|
|
142
142
|
- `usequants` is always `True`.
|
|
143
143
|
- `rm_const` is always `False`.
|
|
144
144
|
- The default `numcut` is 255 instead of 100.
|
|
145
|
+
- A lot of functionality is missing (variable selection, discrete response).
|
|
145
146
|
- There are some additional attributes, and some missing.
|
|
146
147
|
"""
|
|
147
148
|
|
bartz/__init__.py
CHANGED
|
@@ -30,6 +30,11 @@ See the manual at https://gattocrucco.github.io/bartz/docs
|
|
|
30
30
|
|
|
31
31
|
from ._version import __version__
|
|
32
32
|
|
|
33
|
-
from .
|
|
33
|
+
from . import BART
|
|
34
34
|
|
|
35
35
|
from . import debug
|
|
36
|
+
from . import grove
|
|
37
|
+
from . import mcmcstep
|
|
38
|
+
from . import mcmcloop
|
|
39
|
+
from . import prepcovars
|
|
40
|
+
from . import jaxext
|
bartz/_version.py
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
__version__ = '0.
|
|
1
|
+
__version__ = '0.2.0'
|
bartz/debug.py
CHANGED
|
@@ -6,6 +6,7 @@ from jax import lax
|
|
|
6
6
|
|
|
7
7
|
from . import grove
|
|
8
8
|
from . import mcmcstep
|
|
9
|
+
from . import jaxext
|
|
9
10
|
|
|
10
11
|
def print_tree(leaf_tree, var_tree, split_tree, print_all=False):
|
|
11
12
|
|
|
@@ -83,7 +84,7 @@ def trace_depth_distr(split_trees_trace):
|
|
|
83
84
|
def points_per_leaf_distr(var_tree, split_tree, X):
|
|
84
85
|
traverse_tree = jax.vmap(grove.traverse_tree, in_axes=(1, None, None))
|
|
85
86
|
indices = traverse_tree(X, var_tree, split_tree)
|
|
86
|
-
count_tree = jnp.zeros(2 * split_tree.size, dtype=
|
|
87
|
+
count_tree = jnp.zeros(2 * split_tree.size, dtype=jaxext.minimal_unsigned_dtype(indices.size))
|
|
87
88
|
count_tree = count_tree.at[indices].add(1)
|
|
88
89
|
is_leaf = grove.is_actual_leaf(split_tree, add_bottom_level=True).view(jnp.uint8)
|
|
89
90
|
return jnp.bincount(count_tree, is_leaf, length=X.shape[1] + 1)
|
|
@@ -103,7 +104,7 @@ def trace_points_per_leaf_distr(bart, X):
|
|
|
103
104
|
return distr
|
|
104
105
|
|
|
105
106
|
def check_types(leaf_tree, var_tree, split_tree, max_split):
|
|
106
|
-
expected_var_dtype =
|
|
107
|
+
expected_var_dtype = jaxext.minimal_unsigned_dtype(max_split.size - 1)
|
|
107
108
|
expected_split_dtype = max_split.dtype
|
|
108
109
|
return var_tree.dtype == expected_var_dtype and split_tree.dtype == expected_split_dtype
|
|
109
110
|
|
|
@@ -117,7 +118,7 @@ def check_leaf_values(leaf_tree, var_tree, split_tree, max_split):
|
|
|
117
118
|
return jnp.all(jnp.isfinite(leaf_tree))
|
|
118
119
|
|
|
119
120
|
def check_stray_nodes(leaf_tree, var_tree, split_tree, max_split):
|
|
120
|
-
index = jnp.arange(2 * split_tree.size, dtype=
|
|
121
|
+
index = jnp.arange(2 * split_tree.size, dtype=jaxext.minimal_unsigned_dtype(2 * split_tree.size - 1))
|
|
121
122
|
parent_index = index >> 1
|
|
122
123
|
is_not_leaf = split_tree.at[index].get(mode='fill', fill_value=0) != 0
|
|
123
124
|
parent_is_leaf = split_tree[parent_index] == 0
|
|
@@ -134,7 +135,7 @@ check_functions = [
|
|
|
134
135
|
]
|
|
135
136
|
|
|
136
137
|
def check_tree(leaf_tree, var_tree, split_tree, max_split):
|
|
137
|
-
error_type =
|
|
138
|
+
error_type = jaxext.minimal_unsigned_dtype(2 ** len(check_functions) - 1)
|
|
138
139
|
error = error_type(0)
|
|
139
140
|
for i, func in enumerate(check_functions):
|
|
140
141
|
ok = func(leaf_tree, var_tree, split_tree, max_split)
|
bartz/grove.py
CHANGED
|
@@ -44,7 +44,6 @@ import functools
|
|
|
44
44
|
import math
|
|
45
45
|
|
|
46
46
|
import jax
|
|
47
|
-
|
|
48
47
|
from jax import numpy as jnp
|
|
49
48
|
from jax import lax
|
|
50
49
|
|
|
@@ -107,29 +106,47 @@ def traverse_tree(x, var_tree, split_tree):
|
|
|
107
106
|
|
|
108
107
|
carry = (
|
|
109
108
|
jnp.zeros((), bool),
|
|
110
|
-
jnp.ones((), minimal_unsigned_dtype(2 * var_tree.size - 1)),
|
|
109
|
+
jnp.ones((), jaxext.minimal_unsigned_dtype(2 * var_tree.size - 1)),
|
|
111
110
|
)
|
|
112
111
|
|
|
113
112
|
def loop(carry, _):
|
|
114
113
|
leaf_found, index = carry
|
|
115
114
|
|
|
116
|
-
split = split_tree
|
|
117
|
-
var = var_tree
|
|
115
|
+
split = split_tree[index]
|
|
116
|
+
var = var_tree[index]
|
|
118
117
|
|
|
119
|
-
leaf_found |=
|
|
118
|
+
leaf_found |= split == 0
|
|
120
119
|
child_index = (index << 1) + (x[var] >= split)
|
|
121
120
|
index = jnp.where(leaf_found, index, child_index)
|
|
122
121
|
|
|
123
122
|
return (leaf_found, index), None
|
|
124
123
|
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
# - separate and special-case the last iteration
|
|
128
|
-
|
|
129
|
-
depth = 1 + tree_depth(var_tree)
|
|
130
|
-
(_, index), _ = lax.scan(loop, carry, None, depth)
|
|
124
|
+
depth = tree_depth(var_tree)
|
|
125
|
+
(_, index), _ = lax.scan(loop, carry, None, depth, unroll=16)
|
|
131
126
|
return index
|
|
132
127
|
|
|
128
|
+
@functools.partial(jaxext.vmap_nodoc, in_axes=(None, 0, 0))
|
|
129
|
+
@functools.partial(jaxext.vmap_nodoc, in_axes=(1, None, None))
|
|
130
|
+
def traverse_forest(X, var_trees, split_trees):
|
|
131
|
+
"""
|
|
132
|
+
Find the leaves where points fall into.
|
|
133
|
+
|
|
134
|
+
Parameters
|
|
135
|
+
----------
|
|
136
|
+
X : array (p, n)
|
|
137
|
+
The coordinates to evaluate the trees at.
|
|
138
|
+
var_trees : array (m, 2 ** (d - 1))
|
|
139
|
+
The decision axes of the trees.
|
|
140
|
+
split_trees : array (m, 2 ** (d - 1))
|
|
141
|
+
The decision boundaries of the trees.
|
|
142
|
+
|
|
143
|
+
Returns
|
|
144
|
+
-------
|
|
145
|
+
indices : array (m, n)
|
|
146
|
+
The indices of the leaves.
|
|
147
|
+
"""
|
|
148
|
+
return traverse_tree(X, var_trees, split_trees)
|
|
149
|
+
|
|
133
150
|
def evaluate_forest(X, leaf_trees, var_trees, split_trees, dtype):
|
|
134
151
|
"""
|
|
135
152
|
Evaluate a ensemble of trees at an array of points.
|
|
@@ -138,7 +155,7 @@ def evaluate_forest(X, leaf_trees, var_trees, split_trees, dtype):
|
|
|
138
155
|
----------
|
|
139
156
|
X : array (p, n)
|
|
140
157
|
The coordinates to evaluate the trees at.
|
|
141
|
-
leaf_trees : (m, 2 ** d)
|
|
158
|
+
leaf_trees : array (m, 2 ** d)
|
|
142
159
|
The leaf values of the tree or forest. If the input is a forest, the
|
|
143
160
|
first axis is the tree index, and the values are summed.
|
|
144
161
|
var_trees : array (m, 2 ** (d - 1))
|
|
@@ -153,30 +170,13 @@ def evaluate_forest(X, leaf_trees, var_trees, split_trees, dtype):
|
|
|
153
170
|
out : array (n,)
|
|
154
171
|
The sum of the values of the trees at the points in `X`.
|
|
155
172
|
"""
|
|
156
|
-
indices =
|
|
173
|
+
indices = traverse_forest(X, var_trees, split_trees)
|
|
157
174
|
ntree, _ = leaf_trees.shape
|
|
158
|
-
tree_index = jnp.arange(ntree, dtype=minimal_unsigned_dtype(ntree - 1))[:, None]
|
|
175
|
+
tree_index = jnp.arange(ntree, dtype=jaxext.minimal_unsigned_dtype(ntree - 1))[:, None]
|
|
159
176
|
leaves = leaf_trees[tree_index, indices]
|
|
160
177
|
return jnp.sum(leaves, axis=0, dtype=dtype)
|
|
161
|
-
# this sum suggests to swap the vmaps, but I think it's better for X
|
|
162
|
-
|
|
163
|
-
@functools.partial(jax.vmap, in_axes=(None, 0, 0))
|
|
164
|
-
@functools.partial(jax.vmap, in_axes=(1, None, None))
|
|
165
|
-
def _traverse_forest(X, var_trees, split_trees):
|
|
166
|
-
return traverse_tree(X, var_trees, split_trees)
|
|
167
|
-
|
|
168
|
-
def minimal_unsigned_dtype(max_value):
|
|
169
|
-
"""
|
|
170
|
-
Return the smallest unsigned integer dtype that can represent a given
|
|
171
|
-
maximum value.
|
|
172
|
-
"""
|
|
173
|
-
if max_value < 2 ** 8:
|
|
174
|
-
return jnp.uint8
|
|
175
|
-
if max_value < 2 ** 16:
|
|
176
|
-
return jnp.uint16
|
|
177
|
-
if max_value < 2 ** 32:
|
|
178
|
-
return jnp.uint32
|
|
179
|
-
return jnp.uint64
|
|
178
|
+
# this sum suggests to swap the vmaps, but I think it's better for X
|
|
179
|
+
# copying to keep it that way
|
|
180
180
|
|
|
181
181
|
def is_actual_leaf(split_tree, *, add_bottom_level=False):
|
|
182
182
|
"""
|
|
@@ -200,7 +200,7 @@ def is_actual_leaf(split_tree, *, add_bottom_level=False):
|
|
|
200
200
|
if add_bottom_level:
|
|
201
201
|
size *= 2
|
|
202
202
|
is_leaf = jnp.concatenate([is_leaf, jnp.ones_like(is_leaf)])
|
|
203
|
-
index = jnp.arange(size, dtype=minimal_unsigned_dtype(size - 1))
|
|
203
|
+
index = jnp.arange(size, dtype=jaxext.minimal_unsigned_dtype(size - 1))
|
|
204
204
|
parent_index = index >> 1
|
|
205
205
|
parent_nonleaf = split_tree[parent_index].astype(bool)
|
|
206
206
|
parent_nonleaf = parent_nonleaf.at[1].set(True)
|
|
@@ -220,7 +220,7 @@ def is_leaves_parent(split_tree):
|
|
|
220
220
|
is_leaves_parent : bool array (2 ** (d - 1),)
|
|
221
221
|
The mask indicating which nodes have leaf children.
|
|
222
222
|
"""
|
|
223
|
-
index = jnp.arange(split_tree.size, dtype=minimal_unsigned_dtype(2 * split_tree.size - 1))
|
|
223
|
+
index = jnp.arange(split_tree.size, dtype=jaxext.minimal_unsigned_dtype(2 * split_tree.size - 1))
|
|
224
224
|
left_index = index << 1 # left child
|
|
225
225
|
right_index = left_index + 1 # right child
|
|
226
226
|
left_leaf = split_tree.at[left_index].get(mode='fill', fill_value=0) == 0
|
|
@@ -252,4 +252,4 @@ def tree_depths(tree_length):
|
|
|
252
252
|
depth += 1
|
|
253
253
|
depths.append(depth - 1)
|
|
254
254
|
depths[0] = 0
|
|
255
|
-
return jnp.array(depths, minimal_unsigned_dtype(max(depths)))
|
|
255
|
+
return jnp.array(depths, jaxext.minimal_unsigned_dtype(max(depths)))
|
bartz/jaxext.py
CHANGED
|
@@ -10,10 +10,10 @@
|
|
|
10
10
|
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
11
11
|
# copies of the Software, and to permit persons to whom the Software is
|
|
12
12
|
# furnished to do so, subject to the following conditions:
|
|
13
|
-
#
|
|
13
|
+
#
|
|
14
14
|
# The above copyright notice and this permission notice shall be included in all
|
|
15
15
|
# copies or substantial portions of the Software.
|
|
16
|
-
#
|
|
16
|
+
#
|
|
17
17
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
18
18
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
19
19
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
@@ -23,12 +23,19 @@
|
|
|
23
23
|
# SOFTWARE.
|
|
24
24
|
|
|
25
25
|
import functools
|
|
26
|
+
import math
|
|
27
|
+
import warnings
|
|
26
28
|
|
|
27
29
|
from scipy import special
|
|
28
30
|
import jax
|
|
29
31
|
from jax import numpy as jnp
|
|
32
|
+
from jax import tree_util
|
|
33
|
+
from jax import lax
|
|
30
34
|
|
|
31
35
|
def float_type(*args):
|
|
36
|
+
"""
|
|
37
|
+
Determine the jax floating point result type given operands/types.
|
|
38
|
+
"""
|
|
32
39
|
t = jnp.result_type(*args)
|
|
33
40
|
return jnp.sin(jnp.empty(0, t)).dtype
|
|
34
41
|
|
|
@@ -39,8 +46,8 @@ def castto(func, type):
|
|
|
39
46
|
return newfunc
|
|
40
47
|
|
|
41
48
|
def pure_callback_ufunc(callback, dtype, *args, excluded=None, **kwargs):
|
|
42
|
-
""" version of jax.pure_callback that deals correctly with ufuncs,
|
|
43
|
-
see https://github.com/google/jax/issues/17187 """
|
|
49
|
+
""" version of `jax.pure_callback` that deals correctly with ufuncs,
|
|
50
|
+
see `<https://github.com/google/jax/issues/17187>`_ """
|
|
44
51
|
if excluded is None:
|
|
45
52
|
excluded = ()
|
|
46
53
|
shape = jnp.broadcast_shapes(*(
|
|
@@ -63,6 +70,7 @@ class scipy:
|
|
|
63
70
|
|
|
64
71
|
class special:
|
|
65
72
|
|
|
73
|
+
@functools.wraps(special.gammainccinv)
|
|
66
74
|
def gammainccinv(a, y):
|
|
67
75
|
a = jnp.asarray(a)
|
|
68
76
|
y = jnp.asarray(y)
|
|
@@ -73,13 +81,261 @@ class scipy:
|
|
|
73
81
|
class stats:
|
|
74
82
|
|
|
75
83
|
class invgamma:
|
|
76
|
-
|
|
84
|
+
|
|
77
85
|
def ppf(q, a):
|
|
78
86
|
return 1 / scipy.special.gammainccinv(a, q)
|
|
79
87
|
|
|
80
88
|
@functools.wraps(jax.vmap)
|
|
81
89
|
def vmap_nodoc(fun, *args, **kw):
|
|
90
|
+
"""
|
|
91
|
+
Version of `jax.vmap` that preserves the docstring of the input function.
|
|
92
|
+
"""
|
|
82
93
|
doc = fun.__doc__
|
|
83
94
|
fun = jax.vmap(fun, *args, **kw)
|
|
84
95
|
fun.__doc__ = doc
|
|
85
96
|
return fun
|
|
97
|
+
|
|
98
|
+
def huge_value(x):
|
|
99
|
+
"""
|
|
100
|
+
Return the maximum value that can be stored in `x`.
|
|
101
|
+
|
|
102
|
+
Parameters
|
|
103
|
+
----------
|
|
104
|
+
x : array
|
|
105
|
+
A numerical numpy or jax array.
|
|
106
|
+
|
|
107
|
+
Returns
|
|
108
|
+
-------
|
|
109
|
+
maxval : scalar
|
|
110
|
+
The maximum value allowed by `x`'s type (+inf for floats).
|
|
111
|
+
"""
|
|
112
|
+
if jnp.issubdtype(x.dtype, jnp.integer):
|
|
113
|
+
return jnp.iinfo(x.dtype).max
|
|
114
|
+
else:
|
|
115
|
+
return jnp.inf
|
|
116
|
+
|
|
117
|
+
def minimal_unsigned_dtype(max_value):
|
|
118
|
+
"""
|
|
119
|
+
Return the smallest unsigned integer dtype that can represent a given
|
|
120
|
+
maximum value (inclusive).
|
|
121
|
+
"""
|
|
122
|
+
if max_value < 2 ** 8:
|
|
123
|
+
return jnp.uint8
|
|
124
|
+
if max_value < 2 ** 16:
|
|
125
|
+
return jnp.uint16
|
|
126
|
+
if max_value < 2 ** 32:
|
|
127
|
+
return jnp.uint32
|
|
128
|
+
return jnp.uint64
|
|
129
|
+
|
|
130
|
+
def signed_to_unsigned(int_dtype):
|
|
131
|
+
"""
|
|
132
|
+
Map a signed integer type to its unsigned counterpart. Unsigned types are
|
|
133
|
+
passed through.
|
|
134
|
+
"""
|
|
135
|
+
assert jnp.issubdtype(int_dtype, jnp.integer)
|
|
136
|
+
if jnp.issubdtype(int_dtype, jnp.unsignedinteger):
|
|
137
|
+
return int_dtype
|
|
138
|
+
if int_dtype == jnp.int8:
|
|
139
|
+
return jnp.uint8
|
|
140
|
+
if int_dtype == jnp.int16:
|
|
141
|
+
return jnp.uint16
|
|
142
|
+
if int_dtype == jnp.int32:
|
|
143
|
+
return jnp.uint32
|
|
144
|
+
if int_dtype == jnp.int64:
|
|
145
|
+
return jnp.uint64
|
|
146
|
+
|
|
147
|
+
def ensure_unsigned(x):
|
|
148
|
+
"""
|
|
149
|
+
If x has signed integer type, cast it to the unsigned dtype of the same size.
|
|
150
|
+
"""
|
|
151
|
+
return x.astype(signed_to_unsigned(x.dtype))
|
|
152
|
+
|
|
153
|
+
@functools.partial(jax.jit, static_argnums=(1,))
|
|
154
|
+
def unique(x, size, fill_value):
|
|
155
|
+
"""
|
|
156
|
+
Restricted version of `jax.numpy.unique` that uses less memory.
|
|
157
|
+
|
|
158
|
+
Parameters
|
|
159
|
+
----------
|
|
160
|
+
x : 1d array
|
|
161
|
+
The input array.
|
|
162
|
+
size : int
|
|
163
|
+
The length of the output.
|
|
164
|
+
fill_value : scalar
|
|
165
|
+
The value to fill the output with if `size` is greater than the number
|
|
166
|
+
of unique values in `x`.
|
|
167
|
+
|
|
168
|
+
Returns
|
|
169
|
+
-------
|
|
170
|
+
out : array (size,)
|
|
171
|
+
The unique values in `x`, sorted, and right-padded with `fill_value`.
|
|
172
|
+
actual_length : int
|
|
173
|
+
The number of used values in `out`.
|
|
174
|
+
"""
|
|
175
|
+
if x.size == 0:
|
|
176
|
+
return jnp.full(size, fill_value, x.dtype), 0
|
|
177
|
+
if size == 0:
|
|
178
|
+
return jnp.empty(0, x.dtype), 0
|
|
179
|
+
x = jnp.sort(x)
|
|
180
|
+
def loop(carry, x):
|
|
181
|
+
i_out, i_in, last, out = carry
|
|
182
|
+
i_out = jnp.where(x == last, i_out, i_out + 1)
|
|
183
|
+
out = out.at[i_out].set(x)
|
|
184
|
+
return (i_out, i_in + 1, x, out), None
|
|
185
|
+
carry = 0, 0, x[0], jnp.full(size, fill_value, x.dtype)
|
|
186
|
+
(actual_length, _, _, out), _ = jax.lax.scan(loop, carry, x[:size])
|
|
187
|
+
return out, actual_length + 1
|
|
188
|
+
|
|
189
|
+
def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False):
|
|
190
|
+
"""
|
|
191
|
+
Batch a function such that each batch is smaller than a threshold.
|
|
192
|
+
|
|
193
|
+
Parameters
|
|
194
|
+
----------
|
|
195
|
+
func : callable
|
|
196
|
+
A jittable function with positional arguments only, with inputs and
|
|
197
|
+
outputs pytrees of arrays.
|
|
198
|
+
max_io_nbytes : int
|
|
199
|
+
The maximum number of input + output bytes in each batch.
|
|
200
|
+
in_axes : pytree of ints, default 0
|
|
201
|
+
A tree matching the structure of the function input, indicating along
|
|
202
|
+
which axes each array should be batched. If a single integer, it is
|
|
203
|
+
used for all arrays.
|
|
204
|
+
out_axes : pytree of ints, default 0
|
|
205
|
+
The same for outputs.
|
|
206
|
+
return_nbatches : bool, default False
|
|
207
|
+
If True, the number of batches is returned as a second output.
|
|
208
|
+
|
|
209
|
+
Returns
|
|
210
|
+
-------
|
|
211
|
+
batched_func : callable
|
|
212
|
+
A function with the same signature as `func`, but that processes the
|
|
213
|
+
input and output in batches in a loop.
|
|
214
|
+
"""
|
|
215
|
+
|
|
216
|
+
def expand_axes(axes, tree):
|
|
217
|
+
if isinstance(axes, int):
|
|
218
|
+
return tree_util.tree_map(lambda _: axes, tree)
|
|
219
|
+
return tree_util.tree_map(lambda _, axis: axis, tree, axes)
|
|
220
|
+
|
|
221
|
+
def extract_size(axes, tree):
|
|
222
|
+
sizes = tree_util.tree_map(lambda x, axis: x.shape[axis], tree, axes)
|
|
223
|
+
sizes, _ = tree_util.tree_flatten(sizes)
|
|
224
|
+
assert all(s == sizes[0] for s in sizes)
|
|
225
|
+
return sizes[0]
|
|
226
|
+
|
|
227
|
+
def sum_nbytes(tree):
|
|
228
|
+
def nbytes(x):
|
|
229
|
+
return math.prod(x.shape) * x.dtype.itemsize
|
|
230
|
+
return tree_util.tree_reduce(lambda size, x: size + nbytes(x), tree, 0)
|
|
231
|
+
|
|
232
|
+
def next_divisor_small(dividend, min_divisor):
|
|
233
|
+
for divisor in range(min_divisor, int(math.sqrt(dividend)) + 1):
|
|
234
|
+
if dividend % divisor == 0:
|
|
235
|
+
return divisor
|
|
236
|
+
return dividend
|
|
237
|
+
|
|
238
|
+
def next_divisor_large(dividend, min_divisor):
|
|
239
|
+
max_inv_divisor = dividend // min_divisor
|
|
240
|
+
for inv_divisor in range(max_inv_divisor, 0, -1):
|
|
241
|
+
if dividend % inv_divisor == 0:
|
|
242
|
+
return dividend // inv_divisor
|
|
243
|
+
return dividend
|
|
244
|
+
|
|
245
|
+
def next_divisor(dividend, min_divisor):
|
|
246
|
+
if min_divisor * min_divisor <= dividend:
|
|
247
|
+
return next_divisor_small(dividend, min_divisor)
|
|
248
|
+
return next_divisor_large(dividend, min_divisor)
|
|
249
|
+
|
|
250
|
+
def move_axes_out(axes, tree):
|
|
251
|
+
def move_axis_out(axis, x):
|
|
252
|
+
if axis != 0:
|
|
253
|
+
return jnp.moveaxis(x, axis, 0)
|
|
254
|
+
return x
|
|
255
|
+
return tree_util.tree_map(move_axis_out, axes, tree)
|
|
256
|
+
|
|
257
|
+
def move_axes_in(axes, tree):
|
|
258
|
+
def move_axis_in(axis, x):
|
|
259
|
+
if axis != 0:
|
|
260
|
+
return jnp.moveaxis(x, 0, axis)
|
|
261
|
+
return x
|
|
262
|
+
return tree_util.tree_map(move_axis_in, axes, tree)
|
|
263
|
+
|
|
264
|
+
def batch(tree, nbatches):
|
|
265
|
+
def batch(x):
|
|
266
|
+
return x.reshape((nbatches, x.shape[0] // nbatches) + x.shape[1:])
|
|
267
|
+
return tree_util.tree_map(batch, tree)
|
|
268
|
+
|
|
269
|
+
def unbatch(tree):
|
|
270
|
+
def unbatch(x):
|
|
271
|
+
return x.reshape((x.shape[0] * x.shape[1],) + x.shape[2:])
|
|
272
|
+
return tree_util.tree_map(unbatch, tree)
|
|
273
|
+
|
|
274
|
+
def check_same(tree1, tree2):
|
|
275
|
+
def check_same(x1, x2):
|
|
276
|
+
assert x1.shape == x2.shape
|
|
277
|
+
assert x1.dtype == x2.dtype
|
|
278
|
+
tree_util.tree_map(check_same, tree1, tree2)
|
|
279
|
+
|
|
280
|
+
initial_in_axes = in_axes
|
|
281
|
+
initial_out_axes = out_axes
|
|
282
|
+
|
|
283
|
+
@jax.jit
|
|
284
|
+
@functools.wraps(func)
|
|
285
|
+
def batched_func(*args):
|
|
286
|
+
example_result = jax.eval_shape(func, *args)
|
|
287
|
+
|
|
288
|
+
in_axes = expand_axes(initial_in_axes, args)
|
|
289
|
+
out_axes = expand_axes(initial_out_axes, example_result)
|
|
290
|
+
|
|
291
|
+
in_size = extract_size(in_axes, args)
|
|
292
|
+
out_size = extract_size(out_axes, example_result)
|
|
293
|
+
assert in_size == out_size
|
|
294
|
+
size = in_size
|
|
295
|
+
|
|
296
|
+
total_nbytes = sum_nbytes(args) + sum_nbytes(example_result)
|
|
297
|
+
min_nbatches = total_nbytes // max_io_nbytes + bool(total_nbytes % max_io_nbytes)
|
|
298
|
+
nbatches = next_divisor(size, min_nbatches)
|
|
299
|
+
assert 1 <= nbatches <= size
|
|
300
|
+
assert size % nbatches == 0
|
|
301
|
+
assert total_nbytes % nbatches == 0
|
|
302
|
+
|
|
303
|
+
batch_nbytes = total_nbytes // nbatches
|
|
304
|
+
if batch_nbytes > max_io_nbytes:
|
|
305
|
+
assert size == nbatches
|
|
306
|
+
warnings.warn(f'batch_nbytes = {batch_nbytes} > max_io_nbytes = {max_io_nbytes}')
|
|
307
|
+
|
|
308
|
+
def loop(_, args):
|
|
309
|
+
args = move_axes_in(in_axes, args)
|
|
310
|
+
result = func(*args)
|
|
311
|
+
result = move_axes_out(out_axes, result)
|
|
312
|
+
return None, result
|
|
313
|
+
|
|
314
|
+
args = move_axes_out(in_axes, args)
|
|
315
|
+
args = batch(args, nbatches)
|
|
316
|
+
_, result = lax.scan(loop, None, args)
|
|
317
|
+
result = unbatch(result)
|
|
318
|
+
result = move_axes_in(out_axes, result)
|
|
319
|
+
|
|
320
|
+
check_same(example_result, result)
|
|
321
|
+
|
|
322
|
+
if return_nbatches:
|
|
323
|
+
return result, nbatches
|
|
324
|
+
return result
|
|
325
|
+
|
|
326
|
+
return batched_func
|
|
327
|
+
|
|
328
|
+
@tree_util.register_pytree_node_class
|
|
329
|
+
class LeafDict(dict):
|
|
330
|
+
""" dictionary that acts as a leaf in jax pytrees, to store compile-time
|
|
331
|
+
values """
|
|
332
|
+
|
|
333
|
+
def tree_flatten(self):
|
|
334
|
+
return (), self
|
|
335
|
+
|
|
336
|
+
@classmethod
|
|
337
|
+
def tree_unflatten(cls, aux_data, children):
|
|
338
|
+
return aux_data
|
|
339
|
+
|
|
340
|
+
def __repr__(self):
|
|
341
|
+
return f'{__class__.__name__}({super().__repr__()})'
|
bartz/mcmcloop.py
CHANGED
|
@@ -52,7 +52,7 @@ def run_mcmc(bart, n_burn, n_save, n_skip, callback, key):
|
|
|
52
52
|
n_save : int
|
|
53
53
|
The number of iterations to save.
|
|
54
54
|
n_skip : int
|
|
55
|
-
The number of iterations to skip between each saved iteration.
|
|
55
|
+
The number of iterations to skip between each saved iteration, plus 1.
|
|
56
56
|
callback : callable
|
|
57
57
|
An arbitrary function run at each iteration, called with the following
|
|
58
58
|
arguments, passed by keyword:
|
|
@@ -105,16 +105,19 @@ def run_mcmc(bart, n_burn, n_save, n_skip, callback, key):
|
|
|
105
105
|
output = {key: bart[key] for key in tracelist}
|
|
106
106
|
return (bart, i_total + 1, i_skip + 1, key), output
|
|
107
107
|
|
|
108
|
+
def empty_trace(bart, tracelist):
|
|
109
|
+
return {
|
|
110
|
+
key: jnp.empty((0,) + bart[key].shape, bart[key].dtype)
|
|
111
|
+
for key in tracelist
|
|
112
|
+
}
|
|
113
|
+
|
|
108
114
|
if n_burn > 0:
|
|
109
115
|
carry = bart, 0, 0, key
|
|
110
116
|
burnin_loop = functools.partial(inner_loop, tracelist=tracelist_burnin, burnin=True)
|
|
111
117
|
(bart, i_total, _, key), burnin_trace = lax.scan(burnin_loop, carry, None, n_burn)
|
|
112
118
|
else:
|
|
113
119
|
i_total = 0
|
|
114
|
-
burnin_trace =
|
|
115
|
-
key: jnp.empty((0,) + bart[key].shape, bart[key].dtype)
|
|
116
|
-
for key in tracelist_burnin
|
|
117
|
-
}
|
|
120
|
+
burnin_trace = empty_trace(bart, tracelist_burnin)
|
|
118
121
|
|
|
119
122
|
def outer_loop(carry, _):
|
|
120
123
|
bart, i_total, key = carry
|
|
@@ -124,8 +127,11 @@ def run_mcmc(bart, n_burn, n_save, n_skip, callback, key):
|
|
|
124
127
|
output = {key: bart[key] for key in tracelist_main}
|
|
125
128
|
return (bart, i_total, key), output
|
|
126
129
|
|
|
127
|
-
|
|
128
|
-
|
|
130
|
+
if n_save > 0:
|
|
131
|
+
carry = bart, i_total, key
|
|
132
|
+
(bart, _, _), main_trace = lax.scan(outer_loop, carry, None, n_save)
|
|
133
|
+
else:
|
|
134
|
+
main_trace = empty_trace(bart, tracelist_main)
|
|
129
135
|
|
|
130
136
|
return bart, burnin_trace, main_trace
|
|
131
137
|
|
|
@@ -133,7 +139,8 @@ def run_mcmc(bart, n_burn, n_save, n_skip, callback, key):
|
|
|
133
139
|
|
|
134
140
|
@functools.lru_cache
|
|
135
141
|
# cache to make the callback function object unique, such that the jit
|
|
136
|
-
# of run_mcmc recognizes it
|
|
142
|
+
# of run_mcmc recognizes it => with the callback state, I can make
|
|
143
|
+
# printevery a runtime quantity
|
|
137
144
|
def make_simple_print_callback(printevery):
|
|
138
145
|
"""
|
|
139
146
|
Create a logging callback function for MCMC iterations.
|
|
@@ -155,11 +162,12 @@ def make_simple_print_callback(printevery):
|
|
|
155
162
|
grow_acc = bart['grow_acc_count'] / bart['grow_prop_count']
|
|
156
163
|
prune_acc = bart['prune_acc_count'] / bart['prune_prop_count']
|
|
157
164
|
n_total = n_burn + n_save * n_skip
|
|
158
|
-
|
|
165
|
+
printcond = (i_total + 1) % printevery == 0
|
|
166
|
+
debug.callback(_simple_print_callback, burnin, i_total, n_total, grow_prop, grow_acc, prune_prop, prune_acc, printcond)
|
|
159
167
|
return callback
|
|
160
168
|
|
|
161
|
-
def
|
|
162
|
-
if
|
|
169
|
+
def _simple_print_callback(burnin, i_total, n_total, grow_prop, grow_acc, prune_prop, prune_acc, printcond):
|
|
170
|
+
if printcond:
|
|
163
171
|
burnin_flag = ' (burnin)' if burnin else ''
|
|
164
172
|
total_str = str(n_total)
|
|
165
173
|
ndigits = len(total_str)
|
bartz/mcmcstep.py
CHANGED
|
@@ -34,7 +34,6 @@ range of possible values.
|
|
|
34
34
|
"""
|
|
35
35
|
|
|
36
36
|
import functools
|
|
37
|
-
import math
|
|
38
37
|
|
|
39
38
|
import jax
|
|
40
39
|
from jax import random
|
|
@@ -55,6 +54,7 @@ def init(*,
|
|
|
55
54
|
small_float=jnp.float32,
|
|
56
55
|
large_float=jnp.float32,
|
|
57
56
|
min_points_per_leaf=None,
|
|
57
|
+
suffstat_batch_size='auto',
|
|
58
58
|
):
|
|
59
59
|
"""
|
|
60
60
|
Make a BART posterior sampling MCMC initial state.
|
|
@@ -82,6 +82,9 @@ def init(*,
|
|
|
82
82
|
The dtype for scalars, small arrays, and arrays which require accuracy.
|
|
83
83
|
min_points_per_leaf : int, optional
|
|
84
84
|
The minimum number of data points in a leaf node. 0 if not specified.
|
|
85
|
+
suffstat_batch_size : int, None, str, default 'auto'
|
|
86
|
+
The batch size for computing sufficient statistics. `None` for no
|
|
87
|
+
batching. If 'auto', pick a value based on the device of `y`.
|
|
85
88
|
|
|
86
89
|
Returns
|
|
87
90
|
-------
|
|
@@ -104,8 +107,9 @@ def init(*,
|
|
|
104
107
|
The number of grow/prune proposals made during one full MCMC cycle.
|
|
105
108
|
'grow_acc_count', 'prune_acc_count' : int
|
|
106
109
|
The number of grow/prune moves accepted during one full MCMC cycle.
|
|
107
|
-
'p_nonterminal' : large_float array (d
|
|
108
|
-
The probability of a nonterminal node at each depth
|
|
110
|
+
'p_nonterminal' : large_float array (d,)
|
|
111
|
+
The probability of a nonterminal node at each depth, padded with a
|
|
112
|
+
zero.
|
|
109
113
|
'sigma2_alpha' : large_float
|
|
110
114
|
The shape parameter of the inverse gamma prior on the noise variance.
|
|
111
115
|
'sigma2_beta' : large_float
|
|
@@ -121,18 +125,36 @@ def init(*,
|
|
|
121
125
|
'affluence_trees' : bool array (num_trees, 2 ** (d - 1)) or None
|
|
122
126
|
Whether a non-bottom leaf nodes contains twice `min_points_per_leaf`
|
|
123
127
|
datapoints. If `min_points_per_leaf` is not specified, this is None.
|
|
128
|
+
'opt' : LeafDict
|
|
129
|
+
A dictionary with config values:
|
|
130
|
+
|
|
131
|
+
'suffstat_batch_size' : int or None
|
|
132
|
+
The batch size for computing sufficient statistics.
|
|
133
|
+
'small_float' : dtype
|
|
134
|
+
The dtype for large arrays used in the algorithm.
|
|
135
|
+
'large_float' : dtype
|
|
136
|
+
The dtype for scalars, small arrays, and arrays which require
|
|
137
|
+
accuracy.
|
|
138
|
+
'require_min_points' : bool
|
|
139
|
+
Whether the `min_points_per_leaf` parameter is specified.
|
|
124
140
|
"""
|
|
125
141
|
|
|
126
142
|
p_nonterminal = jnp.asarray(p_nonterminal, large_float)
|
|
127
|
-
|
|
143
|
+
p_nonterminal = jnp.pad(p_nonterminal, (0, 1))
|
|
144
|
+
max_depth = p_nonterminal.size
|
|
128
145
|
|
|
129
146
|
@functools.partial(jax.vmap, in_axes=None, out_axes=0, axis_size=num_trees)
|
|
130
147
|
def make_forest(max_depth, dtype):
|
|
131
148
|
return grove.make_tree(max_depth, dtype)
|
|
132
149
|
|
|
150
|
+
small_float = jnp.dtype(small_float)
|
|
151
|
+
large_float = jnp.dtype(large_float)
|
|
152
|
+
y = jnp.asarray(y, small_float)
|
|
153
|
+
suffstat_batch_size = _choose_suffstat_batch_size(suffstat_batch_size, y)
|
|
154
|
+
|
|
133
155
|
bart = dict(
|
|
134
156
|
leaf_trees=make_forest(max_depth, small_float),
|
|
135
|
-
var_trees=make_forest(max_depth - 1,
|
|
157
|
+
var_trees=make_forest(max_depth - 1, jaxext.minimal_unsigned_dtype(X.shape[0] - 1)),
|
|
136
158
|
split_trees=make_forest(max_depth - 1, max_split.dtype),
|
|
137
159
|
resid=jnp.asarray(y, large_float),
|
|
138
160
|
sigma2=jnp.ones((), large_float),
|
|
@@ -143,9 +165,9 @@ def init(*,
|
|
|
143
165
|
p_nonterminal=p_nonterminal,
|
|
144
166
|
sigma2_alpha=jnp.asarray(sigma2_alpha, large_float),
|
|
145
167
|
sigma2_beta=jnp.asarray(sigma2_beta, large_float),
|
|
146
|
-
max_split=max_split,
|
|
147
|
-
y=
|
|
148
|
-
X=X,
|
|
168
|
+
max_split=jnp.asarray(max_split),
|
|
169
|
+
y=y,
|
|
170
|
+
X=jnp.asarray(X),
|
|
149
171
|
min_points_per_leaf=(
|
|
150
172
|
None if min_points_per_leaf is None else
|
|
151
173
|
jnp.asarray(min_points_per_leaf)
|
|
@@ -154,10 +176,32 @@ def init(*,
|
|
|
154
176
|
None if min_points_per_leaf is None else
|
|
155
177
|
make_forest(max_depth - 1, bool).at[:, 1].set(y.size >= 2 * min_points_per_leaf)
|
|
156
178
|
),
|
|
179
|
+
opt=jaxext.LeafDict(
|
|
180
|
+
suffstat_batch_size=suffstat_batch_size,
|
|
181
|
+
small_float=small_float,
|
|
182
|
+
large_float=large_float,
|
|
183
|
+
require_min_points=min_points_per_leaf is not None,
|
|
184
|
+
),
|
|
157
185
|
)
|
|
158
186
|
|
|
159
187
|
return bart
|
|
160
188
|
|
|
189
|
+
def _choose_suffstat_batch_size(size, y):
|
|
190
|
+
if size == 'auto':
|
|
191
|
+
platform = y.devices().pop().platform
|
|
192
|
+
if platform == 'cpu':
|
|
193
|
+
return None
|
|
194
|
+
# maybe I should batch residuals (not counts) for numerical
|
|
195
|
+
# accuracy, even if it's slower
|
|
196
|
+
elif platform == 'gpu':
|
|
197
|
+
return 128 # 128 is good on A100, and V100 at high n
|
|
198
|
+
# 512 is good on T4, and V100 at low n
|
|
199
|
+
else:
|
|
200
|
+
raise KeyError(f'Unknown platform: {platform}')
|
|
201
|
+
elif size is not None:
|
|
202
|
+
return int(size)
|
|
203
|
+
return size
|
|
204
|
+
|
|
161
205
|
def step(bart, key):
|
|
162
206
|
"""
|
|
163
207
|
Perform one full MCMC step on a BART state.
|
|
@@ -196,11 +240,14 @@ def sample_trees(bart, key):
|
|
|
196
240
|
|
|
197
241
|
Notes
|
|
198
242
|
-----
|
|
199
|
-
This function zeroes the proposal counters.
|
|
243
|
+
This function zeroes the proposal counters before using them.
|
|
200
244
|
"""
|
|
245
|
+
bart = bart.copy()
|
|
201
246
|
key, subkey = random.split(key)
|
|
202
247
|
grow_moves, prune_moves = sample_moves(bart, subkey)
|
|
203
|
-
|
|
248
|
+
bart['var_trees'] = grow_moves['var_tree']
|
|
249
|
+
grow_leaf_indices = grove.traverse_forest(bart['X'], grow_moves['var_tree'], grow_moves['split_tree'])
|
|
250
|
+
return accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, grow_leaf_indices, key)
|
|
204
251
|
|
|
205
252
|
def sample_moves(bart, key):
|
|
206
253
|
"""
|
|
@@ -216,20 +263,7 @@ def sample_moves(bart, key):
|
|
|
216
263
|
Returns
|
|
217
264
|
-------
|
|
218
265
|
grow_moves, prune_moves : dict
|
|
219
|
-
The proposals for grow and prune moves
|
|
220
|
-
|
|
221
|
-
'allowed' : bool array (num_trees,)
|
|
222
|
-
Whether the move is possible.
|
|
223
|
-
'node' : int array (num_trees,)
|
|
224
|
-
The index of the leaf to grow or node to prune.
|
|
225
|
-
'var_tree' : int array (num_trees, 2 ** (d - 1),)
|
|
226
|
-
The new decision axes of the tree.
|
|
227
|
-
'split_tree' : int array (num_trees, 2 ** (d - 1),)
|
|
228
|
-
The new decision boundaries of the tree.
|
|
229
|
-
'partial_ratio' : float array (num_trees,)
|
|
230
|
-
A factor of the Metropolis-Hastings ratio of the move. It lacks
|
|
231
|
-
the likelihood ratio, and the probability of proposing the prune
|
|
232
|
-
move. For the prune move, the ratio is inverted.
|
|
266
|
+
The proposals for grow and prune moves. See `grow_move` and `prune_move`.
|
|
233
267
|
"""
|
|
234
268
|
key = random.split(key, bart['var_trees'].shape[0])
|
|
235
269
|
return sample_moves_vmap_trees(bart['var_trees'], bart['split_trees'], bart['affluence_trees'], bart['max_split'], bart['p_nonterminal'], key)
|
|
@@ -260,7 +294,7 @@ def grow_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, ke
|
|
|
260
294
|
Whether a leaf has enough points to be grown.
|
|
261
295
|
max_split : array (p,)
|
|
262
296
|
The maximum split index for each variable.
|
|
263
|
-
p_nonterminal : array (d
|
|
297
|
+
p_nonterminal : array (d,)
|
|
264
298
|
The probability of a nonterminal node at each depth.
|
|
265
299
|
key : jax.dtypes.prng_key array
|
|
266
300
|
A jax random key.
|
|
@@ -292,16 +326,16 @@ def grow_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, ke
|
|
|
292
326
|
var_tree = var_tree.at[leaf_to_grow].set(var.astype(var_tree.dtype))
|
|
293
327
|
|
|
294
328
|
split = choose_split(var_tree, split_tree, max_split, leaf_to_grow, key2)
|
|
295
|
-
|
|
329
|
+
split_tree = split_tree.at[leaf_to_grow].set(split.astype(split_tree.dtype))
|
|
296
330
|
|
|
297
|
-
ratio = compute_partial_ratio(num_growable, num_prunable, p_nonterminal, leaf_to_grow, split_tree
|
|
331
|
+
ratio = compute_partial_ratio(num_growable, num_prunable, p_nonterminal, leaf_to_grow, split_tree)
|
|
298
332
|
|
|
299
333
|
return dict(
|
|
300
334
|
allowed=allowed,
|
|
301
335
|
node=leaf_to_grow,
|
|
302
|
-
var_tree=var_tree,
|
|
303
|
-
split_tree=new_split_tree,
|
|
304
336
|
partial_ratio=ratio,
|
|
337
|
+
var_tree=var_tree,
|
|
338
|
+
split_tree=split_tree,
|
|
305
339
|
)
|
|
306
340
|
|
|
307
341
|
def choose_leaf(split_tree, affluence_tree, key):
|
|
@@ -464,7 +498,7 @@ def ancestor_variables(var_tree, max_split, node_index):
|
|
|
464
498
|
the parent. Unused spots are filled with `p`.
|
|
465
499
|
"""
|
|
466
500
|
max_num_ancestors = grove.tree_depth(var_tree) - 1
|
|
467
|
-
ancestor_vars = jnp.zeros(max_num_ancestors,
|
|
501
|
+
ancestor_vars = jnp.zeros(max_num_ancestors, jaxext.minimal_unsigned_dtype(max_split.size))
|
|
468
502
|
carry = ancestor_vars.size - 1, node_index, ancestor_vars
|
|
469
503
|
def loop(carry, _):
|
|
470
504
|
i, index, ancestor_vars = carry
|
|
@@ -569,7 +603,7 @@ def choose_split(var_tree, split_tree, max_split, leaf_index, key):
|
|
|
569
603
|
l, r = split_range(var_tree, split_tree, max_split, leaf_index, var)
|
|
570
604
|
return random.randint(key, (), l, r)
|
|
571
605
|
|
|
572
|
-
def compute_partial_ratio(num_growable, num_prunable, p_nonterminal, leaf_to_grow,
|
|
606
|
+
def compute_partial_ratio(num_growable, num_prunable, p_nonterminal, leaf_to_grow, new_split_tree):
|
|
573
607
|
"""
|
|
574
608
|
Compute the product of the transition and prior ratios of a grow move.
|
|
575
609
|
|
|
@@ -580,12 +614,10 @@ def compute_partial_ratio(num_growable, num_prunable, p_nonterminal, leaf_to_gro
|
|
|
580
614
|
num_prunable : int
|
|
581
615
|
The number of leaf parents that could be pruned, after converting the
|
|
582
616
|
leaf to be grown to a non-terminal node.
|
|
583
|
-
p_nonterminal : array (d
|
|
617
|
+
p_nonterminal : array (d,)
|
|
584
618
|
The probability of a nonterminal node at each depth.
|
|
585
619
|
leaf_to_grow : int
|
|
586
620
|
The index of the leaf to grow.
|
|
587
|
-
initial_split_tree : array (2 ** (d - 1),)
|
|
588
|
-
The splitting points of the tree, before the leaf is grown.
|
|
589
621
|
new_split_tree : array (2 ** (d - 1),)
|
|
590
622
|
The splitting points of the tree, after the leaf is grown.
|
|
591
623
|
|
|
@@ -600,14 +632,18 @@ def compute_partial_ratio(num_growable, num_prunable, p_nonterminal, leaf_to_gro
|
|
|
600
632
|
# the two ratios also contain factors num_available_split *
|
|
601
633
|
# num_available_var, but they cancel out
|
|
602
634
|
|
|
603
|
-
|
|
604
|
-
|
|
635
|
+
prune_allowed = leaf_to_grow != 1
|
|
636
|
+
# prune allowed <---> the initial tree is not a root
|
|
637
|
+
# leaf to grow is root --> the tree can only be a root
|
|
638
|
+
# tree is a root --> the only leaf I can grow is root
|
|
639
|
+
|
|
640
|
+
p_grow = jnp.where(prune_allowed, 0.5, 1)
|
|
605
641
|
|
|
606
642
|
trans_ratio = num_growable / (p_grow * num_prunable)
|
|
607
643
|
|
|
608
|
-
depth = grove.tree_depths(
|
|
644
|
+
depth = grove.tree_depths(new_split_tree.size)[leaf_to_grow]
|
|
609
645
|
p_parent = p_nonterminal[depth]
|
|
610
|
-
cp_children = 1 - p_nonterminal
|
|
646
|
+
cp_children = 1 - p_nonterminal[depth + 1]
|
|
611
647
|
tree_ratio = cp_children * cp_children * p_parent / (1 - p_parent)
|
|
612
648
|
|
|
613
649
|
return trans_ratio * tree_ratio
|
|
@@ -626,7 +662,7 @@ def prune_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, k
|
|
|
626
662
|
Whether a leaf has enough points to be grown.
|
|
627
663
|
max_split : array (p,)
|
|
628
664
|
The maximum split index for each variable.
|
|
629
|
-
p_nonterminal : array (d
|
|
665
|
+
p_nonterminal : array (d,)
|
|
630
666
|
The probability of a nonterminal node at each depth.
|
|
631
667
|
key : jax.dtypes.prng_key array
|
|
632
668
|
A jax random key.
|
|
@@ -639,28 +675,20 @@ def prune_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, k
|
|
|
639
675
|
'allowed' : bool
|
|
640
676
|
Whether the move is possible.
|
|
641
677
|
'node' : int
|
|
642
|
-
The index of the
|
|
643
|
-
'var_tree' : array (2 ** (d - 1),)
|
|
644
|
-
The new decision axes of the tree.
|
|
645
|
-
'split_tree' : array (2 ** (d - 1),)
|
|
646
|
-
The new decision boundaries of the tree.
|
|
678
|
+
The index of the node to prune.
|
|
647
679
|
'partial_ratio' : float
|
|
648
680
|
A factor of the Metropolis-Hastings ratio of the move. It lacks
|
|
649
681
|
the likelihood ratio and the probability of proposing the prune
|
|
650
682
|
move. This ratio is inverted.
|
|
651
683
|
"""
|
|
652
684
|
node_to_prune, num_prunable, num_growable = choose_leaf_parent(split_tree, affluence_tree, key)
|
|
653
|
-
allowed =
|
|
685
|
+
allowed = split_tree[1].astype(bool) # allowed iff the tree is not a root
|
|
654
686
|
|
|
655
|
-
|
|
656
|
-
|
|
657
|
-
ratio = compute_partial_ratio(num_growable, num_prunable, p_nonterminal, node_to_prune, new_split_tree, split_tree)
|
|
687
|
+
ratio = compute_partial_ratio(num_growable, num_prunable, p_nonterminal, node_to_prune, split_tree)
|
|
658
688
|
|
|
659
689
|
return dict(
|
|
660
690
|
allowed=allowed,
|
|
661
691
|
node=node_to_prune,
|
|
662
|
-
var_tree=var_tree,
|
|
663
|
-
split_tree=new_split_tree,
|
|
664
692
|
partial_ratio=ratio, # it is inverted in accept_move_and_sample_leaves
|
|
665
693
|
)
|
|
666
694
|
|
|
@@ -702,29 +730,37 @@ def choose_leaf_parent(split_tree, affluence_tree, key):
|
|
|
702
730
|
|
|
703
731
|
return node_to_prune, num_prunable, num_growable
|
|
704
732
|
|
|
705
|
-
def
|
|
733
|
+
def accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, grow_leaf_indices, key):
|
|
706
734
|
"""
|
|
707
|
-
|
|
735
|
+
Accept or reject the proposed moves and sample the new leaf values.
|
|
708
736
|
|
|
709
737
|
Parameters
|
|
710
738
|
----------
|
|
711
|
-
|
|
712
|
-
|
|
739
|
+
bart : dict
|
|
740
|
+
A BART mcmc state.
|
|
741
|
+
grow_moves : dict
|
|
742
|
+
The proposals for grow moves, batched over the first axis. See
|
|
743
|
+
`grow_move`.
|
|
744
|
+
prune_moves : dict
|
|
745
|
+
The proposals for prune moves, batched over the first axis. See
|
|
746
|
+
`prune_move`.
|
|
747
|
+
grow_leaf_indices : int array (num_trees, n)
|
|
748
|
+
The leaf indices of the trees proposed by the grow move.
|
|
749
|
+
key : jax.dtypes.prng_key array
|
|
750
|
+
A jax random key.
|
|
713
751
|
|
|
714
752
|
Returns
|
|
715
753
|
-------
|
|
716
|
-
|
|
717
|
-
|
|
754
|
+
bart : dict
|
|
755
|
+
The new BART mcmc state.
|
|
718
756
|
"""
|
|
719
|
-
return split_tree.at[1].get(mode='fill', fill_value=0).astype(bool)
|
|
720
|
-
|
|
721
|
-
def accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, key):
|
|
722
757
|
bart = bart.copy()
|
|
723
758
|
def loop(carry, item):
|
|
724
759
|
resid = carry.pop('resid')
|
|
725
760
|
resid, carry, trees = accept_move_and_sample_leaves(
|
|
726
761
|
bart['X'],
|
|
727
762
|
len(bart['leaf_trees']),
|
|
763
|
+
bart['opt']['suffstat_batch_size'],
|
|
728
764
|
resid,
|
|
729
765
|
bart['sigma2'],
|
|
730
766
|
bart['min_points_per_leaf'],
|
|
@@ -740,11 +776,11 @@ def accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, key):
|
|
|
740
776
|
carry['resid'] = bart['resid']
|
|
741
777
|
items = (
|
|
742
778
|
bart['leaf_trees'],
|
|
743
|
-
bart['var_trees'],
|
|
744
779
|
bart['split_trees'],
|
|
745
780
|
bart['affluence_trees'],
|
|
746
781
|
grow_moves,
|
|
747
782
|
prune_moves,
|
|
783
|
+
grow_leaf_indices,
|
|
748
784
|
random.split(key, len(bart['leaf_trees'])),
|
|
749
785
|
)
|
|
750
786
|
carry, trees = lax.scan(loop, carry, items)
|
|
@@ -752,11 +788,50 @@ def accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, key):
|
|
|
752
788
|
bart.update(trees)
|
|
753
789
|
return bart
|
|
754
790
|
|
|
755
|
-
def accept_move_and_sample_leaves(X, ntree, resid, sigma2, min_points_per_leaf, counts, leaf_tree,
|
|
756
|
-
|
|
757
|
-
|
|
758
|
-
|
|
759
|
-
|
|
791
|
+
def accept_move_and_sample_leaves(X, ntree, suffstat_batch_size, resid, sigma2, min_points_per_leaf, counts, leaf_tree, split_tree, affluence_tree, grow_move, prune_move, grow_leaf_indices, key):
|
|
792
|
+
"""
|
|
793
|
+
Accept or reject a proposed move and sample the new leaf values.
|
|
794
|
+
|
|
795
|
+
Parameters
|
|
796
|
+
----------
|
|
797
|
+
X : int array (p, n)
|
|
798
|
+
The predictors.
|
|
799
|
+
ntree : int
|
|
800
|
+
The number of trees in the forest.
|
|
801
|
+
suffstat_batch_size : int, None
|
|
802
|
+
The batch size for computing sufficient statistics.
|
|
803
|
+
resid : float array (n,)
|
|
804
|
+
The residuals (data minus forest value).
|
|
805
|
+
sigma2 : float
|
|
806
|
+
The noise variance.
|
|
807
|
+
min_points_per_leaf : int or None
|
|
808
|
+
The minimum number of data points in a leaf node.
|
|
809
|
+
counts : dict
|
|
810
|
+
The acceptance counts from the mcmc state dict.
|
|
811
|
+
leaf_tree : float array (2 ** d,)
|
|
812
|
+
The leaf values of the tree.
|
|
813
|
+
split_tree : int array (2 ** (d - 1),)
|
|
814
|
+
The decision boundaries of the tree.
|
|
815
|
+
affluence_tree : bool array (2 ** (d - 1),) or None
|
|
816
|
+
Whether a leaf has enough points to be grown.
|
|
817
|
+
grow_move : dict
|
|
818
|
+
The proposal for the grow move. See `grow_move`.
|
|
819
|
+
prune_move : dict
|
|
820
|
+
The proposal for the prune move. See `prune_move`.
|
|
821
|
+
grow_leaf_indices : int array (n,)
|
|
822
|
+
The leaf indices of the tree proposed by the grow move.
|
|
823
|
+
key : jax.dtypes.prng_key array
|
|
824
|
+
A jax random key.
|
|
825
|
+
|
|
826
|
+
Returns
|
|
827
|
+
-------
|
|
828
|
+
resid : float array (n,)
|
|
829
|
+
The updated residuals (data minus forest value).
|
|
830
|
+
counts : dict
|
|
831
|
+
The updated acceptance counts.
|
|
832
|
+
trees : dict
|
|
833
|
+
The updated tree arrays.
|
|
834
|
+
"""
|
|
760
835
|
|
|
761
836
|
# compute leaf indices in starting tree
|
|
762
837
|
grow_node = grow_move['node']
|
|
@@ -782,10 +857,7 @@ def accept_move_and_sample_leaves(X, ntree, resid, sigma2, min_points_per_leaf,
|
|
|
782
857
|
resid += leaf_tree[leaf_indices]
|
|
783
858
|
|
|
784
859
|
# aggregate residuals and count units per leaf
|
|
785
|
-
grow_resid_tree =
|
|
786
|
-
grow_resid_tree = grow_resid_tree.at[grow_leaf_indices].add(resid)
|
|
787
|
-
grow_count_tree = jnp.zeros_like(leaf_tree, grove.minimal_unsigned_dtype(resid.size))
|
|
788
|
-
grow_count_tree = grow_count_tree.at[grow_leaf_indices].add(1)
|
|
860
|
+
grow_resid_tree, grow_count_tree = sufficient_stat(resid, grow_leaf_indices, leaf_tree.size, suffstat_batch_size)
|
|
789
861
|
|
|
790
862
|
# compute aggregations in starting tree
|
|
791
863
|
# I do not zero the children because garbage there does not matter
|
|
@@ -833,10 +905,10 @@ def accept_move_and_sample_leaves(X, ntree, resid, sigma2, min_points_per_leaf,
|
|
|
833
905
|
|
|
834
906
|
# pick trees for chosen move
|
|
835
907
|
trees = {}
|
|
836
|
-
var_tree = jnp.where(do_grow, grow_move['var_tree'], var_tree)
|
|
837
908
|
split_tree = jnp.where(do_grow, grow_move['split_tree'], split_tree)
|
|
838
|
-
|
|
839
|
-
split_tree =
|
|
909
|
+
# the prune var tree is equal to the initial one, because I leave garbage values behind
|
|
910
|
+
split_tree = split_tree.at[prune_node].set(
|
|
911
|
+
jnp.where(do_prune, 0, split_tree[prune_node]))
|
|
840
912
|
if min_points_per_leaf is not None:
|
|
841
913
|
affluence_tree = jnp.where(do_grow, grow_affluence_tree, affluence_tree)
|
|
842
914
|
affluence_tree = jnp.where(do_prune, prune_affluence_tree, affluence_tree)
|
|
@@ -869,13 +941,60 @@ def accept_move_and_sample_leaves(X, ntree, resid, sigma2, min_points_per_leaf,
|
|
|
869
941
|
# pack trees
|
|
870
942
|
trees = {
|
|
871
943
|
'leaf_trees': leaf_tree,
|
|
872
|
-
'var_trees': var_tree,
|
|
873
944
|
'split_trees': split_tree,
|
|
874
945
|
'affluence_trees': affluence_tree,
|
|
875
946
|
}
|
|
876
947
|
|
|
877
948
|
return resid, counts, trees
|
|
878
949
|
|
|
950
|
+
def sufficient_stat(resid, leaf_indices, tree_size, batch_size):
|
|
951
|
+
"""
|
|
952
|
+
Compute the sufficient statistics for the likelihood ratio of a tree move.
|
|
953
|
+
|
|
954
|
+
Parameters
|
|
955
|
+
----------
|
|
956
|
+
resid : float array (n,)
|
|
957
|
+
The residuals (data minus forest value).
|
|
958
|
+
leaf_indices : int array (n,)
|
|
959
|
+
The leaf indices of the tree (in which leaf each data point falls into).
|
|
960
|
+
tree_size : int
|
|
961
|
+
The size of the tree array (2 ** d).
|
|
962
|
+
batch_size : int, None
|
|
963
|
+
The batch size for the aggregation. Batching increases numerical
|
|
964
|
+
accuracy and parallelism.
|
|
965
|
+
|
|
966
|
+
Returns
|
|
967
|
+
-------
|
|
968
|
+
resid_tree : float array (2 ** d,)
|
|
969
|
+
The sum of the residuals at data points in each leaf.
|
|
970
|
+
count_tree : int array (2 ** d,)
|
|
971
|
+
The number of data points in each leaf.
|
|
972
|
+
"""
|
|
973
|
+
if batch_size is None:
|
|
974
|
+
aggr_func = _aggregate_scatter
|
|
975
|
+
else:
|
|
976
|
+
aggr_func = functools.partial(_aggregate_batched, batch_size=batch_size)
|
|
977
|
+
resid_tree = aggr_func(resid, leaf_indices, tree_size, jnp.float32)
|
|
978
|
+
count_tree = aggr_func(1, leaf_indices, tree_size, jnp.uint32)
|
|
979
|
+
return resid_tree, count_tree
|
|
980
|
+
|
|
981
|
+
def _aggregate_scatter(values, indices, size, dtype):
|
|
982
|
+
return (jnp
|
|
983
|
+
.zeros(size, dtype)
|
|
984
|
+
.at[indices]
|
|
985
|
+
.add(values)
|
|
986
|
+
)
|
|
987
|
+
|
|
988
|
+
def _aggregate_batched(values, indices, size, dtype, batch_size):
|
|
989
|
+
nbatches = indices.size // batch_size + bool(indices.size % batch_size)
|
|
990
|
+
batch_indices = jnp.arange(indices.size) // batch_size
|
|
991
|
+
return (jnp
|
|
992
|
+
.zeros((nbatches, size), dtype)
|
|
993
|
+
.at[batch_indices, indices]
|
|
994
|
+
.add(values)
|
|
995
|
+
.sum(axis=0)
|
|
996
|
+
)
|
|
997
|
+
|
|
879
998
|
def compute_p_prune_back(new_split_tree, new_affluence_tree):
|
|
880
999
|
"""
|
|
881
1000
|
Compute the probability of proposing a prune move after doing a grow move.
|
bartz/prepcovars.py
CHANGED
|
@@ -27,8 +27,10 @@ import functools
|
|
|
27
27
|
import jax
|
|
28
28
|
from jax import numpy as jnp
|
|
29
29
|
|
|
30
|
+
from . import jaxext
|
|
30
31
|
from . import grove
|
|
31
32
|
|
|
33
|
+
@functools.partial(jax.jit, static_argnums=(1,))
|
|
32
34
|
def quantilized_splits_from_matrix(X, max_bins):
|
|
33
35
|
"""
|
|
34
36
|
Determine bins that make the distribution of each predictor uniform.
|
|
@@ -52,48 +54,41 @@ def quantilized_splits_from_matrix(X, max_bins):
|
|
|
52
54
|
The number of actually used values in each row of `splits`.
|
|
53
55
|
"""
|
|
54
56
|
out_length = min(max_bins, X.shape[1]) - 1
|
|
55
|
-
return
|
|
57
|
+
# return _quantilized_splits_from_matrix(X, out_length)
|
|
58
|
+
@functools.partial(jaxext.autobatch, max_io_nbytes=500_000_000)
|
|
59
|
+
def func(X):
|
|
60
|
+
return _quantilized_splits_from_matrix(X, out_length)
|
|
61
|
+
return func(X)
|
|
56
62
|
|
|
57
63
|
@functools.partial(jax.vmap, in_axes=(0, None))
|
|
58
|
-
def
|
|
59
|
-
huge = huge_value(x)
|
|
60
|
-
u =
|
|
61
|
-
actual_length
|
|
62
|
-
|
|
64
|
+
def _quantilized_splits_from_matrix(x, out_length):
|
|
65
|
+
huge = jaxext.huge_value(x)
|
|
66
|
+
u, actual_length = jaxext.unique(x, size=x.size, fill_value=huge)
|
|
67
|
+
actual_length -= 1
|
|
68
|
+
if jnp.issubdtype(x.dtype, jnp.integer):
|
|
69
|
+
midpoints = u[:-1] + jaxext.ensure_unsigned(u[1:] - u[:-1]) // 2
|
|
70
|
+
indices = jnp.arange(midpoints.size, dtype=jaxext.minimal_unsigned_dtype(midpoints.size - 1))
|
|
71
|
+
midpoints = jnp.where(indices < actual_length, midpoints, huge)
|
|
72
|
+
else:
|
|
73
|
+
midpoints = (u[1:] + u[:-1]) / 2
|
|
63
74
|
indices = jnp.linspace(-1, actual_length, out_length + 2)[1:-1]
|
|
64
|
-
indices = jnp.around(indices).astype(
|
|
75
|
+
indices = jnp.around(indices).astype(jaxext.minimal_unsigned_dtype(midpoints.size - 1))
|
|
65
76
|
# indices calculation with float rather than int to avoid potential
|
|
66
77
|
# overflow with int32, and to round to nearest instead of rounding down
|
|
67
78
|
decimated_midpoints = midpoints[indices]
|
|
68
79
|
truncated_midpoints = midpoints[:out_length]
|
|
69
80
|
splits = jnp.where(actual_length > out_length, decimated_midpoints, truncated_midpoints)
|
|
70
81
|
max_split = jnp.minimum(actual_length, out_length)
|
|
71
|
-
max_split = max_split.astype(
|
|
82
|
+
max_split = max_split.astype(jaxext.minimal_unsigned_dtype(out_length))
|
|
72
83
|
return splits, max_split
|
|
73
84
|
|
|
74
|
-
|
|
75
|
-
"""
|
|
76
|
-
Return the maximum value that can be stored in `x`.
|
|
77
|
-
|
|
78
|
-
Parameters
|
|
79
|
-
----------
|
|
80
|
-
x : array
|
|
81
|
-
A numerical numpy or jax array.
|
|
82
|
-
|
|
83
|
-
Returns
|
|
84
|
-
-------
|
|
85
|
-
maxval : scalar
|
|
86
|
-
The maximum value allowed by `x`'s type (+inf for floats).
|
|
87
|
-
"""
|
|
88
|
-
if jnp.issubdtype(x.dtype, jnp.integer):
|
|
89
|
-
return jnp.iinfo(x.dtype).max
|
|
90
|
-
else:
|
|
91
|
-
return jnp.inf
|
|
92
|
-
|
|
85
|
+
@jax.jit
|
|
93
86
|
def bin_predictors(X, splits):
|
|
94
87
|
"""
|
|
95
88
|
Bin the predictors according to the given splits.
|
|
96
89
|
|
|
90
|
+
A value ``x`` is mapped to bin ``i`` iff ``splits[i - 1] < x <= splits[i]``.
|
|
91
|
+
|
|
97
92
|
Parameters
|
|
98
93
|
----------
|
|
99
94
|
X : array (p, n)
|
|
@@ -110,9 +105,9 @@ def bin_predictors(X, splits):
|
|
|
110
105
|
A matrix with `p` predictors and `n` observations, where each predictor
|
|
111
106
|
has been replaced by the index of the bin it falls into.
|
|
112
107
|
"""
|
|
113
|
-
return
|
|
108
|
+
return _bin_predictors(X, splits)
|
|
114
109
|
|
|
115
110
|
@jax.vmap
|
|
116
|
-
def
|
|
117
|
-
dtype =
|
|
111
|
+
def _bin_predictors(x, splits):
|
|
112
|
+
dtype = jaxext.minimal_unsigned_dtype(splits.size)
|
|
118
113
|
return jnp.searchsorted(splits, x).astype(dtype)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: bartz
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.2.0
|
|
4
4
|
Summary: A JAX implementation of BART
|
|
5
5
|
Home-page: https://github.com/Gattocrucco/bartz
|
|
6
6
|
License: MIT
|
|
@@ -20,7 +20,13 @@ Project-URL: Bug Tracker, https://github.com/Gattocrucco/bartz/issues
|
|
|
20
20
|
Project-URL: Repository, https://github.com/Gattocrucco/bartz
|
|
21
21
|
Description-Content-Type: text/markdown
|
|
22
22
|
|
|
23
|
+
[](https://pypi.org/project/bartz/)
|
|
24
|
+
|
|
23
25
|
# BART vectoriZed
|
|
24
26
|
|
|
25
27
|
A branchless vectorized implementation of Bayesian Additive Regression Trees (BART) in JAX.
|
|
26
28
|
|
|
29
|
+
BART is a nonparametric Bayesian regression technique. Given predictors $X$ and responses $y$, BART finds a function to predict $y$ given $X$. The result of the inference is a sample of possible functions, representing the uncertainty over the determination of the function.
|
|
30
|
+
|
|
31
|
+
This Python module provides an implementation of BART that runs on GPU, to process large datasets faster. It is also a good on CPU. Most other implementations of BART are for R, and run on CPU only.
|
|
32
|
+
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
bartz/BART.py,sha256=pRG7mALenknX2JHqY-VyhO9-evDgEC6hWBp4jpecBdM,15801
|
|
2
|
+
bartz/__init__.py,sha256=E96vsP0bZ8brejpZmEmRoXuMsUdinO_B_SKUUl1rLsg,1448
|
|
3
|
+
bartz/_version.py,sha256=FVHPBGkfhbQDi_z3v0PiKJrXXqXOx0vGW_1VaqNJi7U,22
|
|
4
|
+
bartz/debug.py,sha256=9ZH-JfwZVu5OPhHBEyXQHAU5H9KIu1vxLK7yNv4m4Ew,5314
|
|
5
|
+
bartz/grove.py,sha256=Wj_7jHl9w3uwuVdH4hoeXowimGpdRE2lGIzr4aDkzsI,8291
|
|
6
|
+
bartz/jaxext.py,sha256=VYA41D5F7DYcAAVtkcZtEN927HxQGOOQM-uGsgr2CPc,10996
|
|
7
|
+
bartz/mcmcloop.py,sha256=lheLrjVxmlyQzc_92zeNsFhdkrhEWQEjoAWFbVzknnw,7701
|
|
8
|
+
bartz/mcmcstep.py,sha256=3ba94hXBW4UAZ11SFshnwJAgn6bpIqSZdRy_wQjEkrk,39278
|
|
9
|
+
bartz/prepcovars.py,sha256=iiQ0WjSj4--l5DgPW626Qg2SSB6ljnaaUsBz_A8kFrI,4634
|
|
10
|
+
bartz-0.2.0.dist-info/LICENSE,sha256=heuIJZQK9IexJYC-fYHoLUrgj8HG8yS3G072EvKh-94,1073
|
|
11
|
+
bartz-0.2.0.dist-info/METADATA,sha256=LiYjTAzgoxUM2MAuaKtf0VW-_zciTKBkTX5B7HNvUbI,1490
|
|
12
|
+
bartz-0.2.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
|
13
|
+
bartz-0.2.0.dist-info/RECORD,,
|
bartz-0.1.0.dist-info/RECORD
DELETED
|
@@ -1,13 +0,0 @@
|
|
|
1
|
-
bartz/__init__.py,sha256=40tX5XHoTiGnZcoeogVpyNOM_5rbHt-Y6zTI0NS7OA4,1345
|
|
2
|
-
bartz/_version.py,sha256=IMjkMO3twhQzluVTo8Z6rE7Eg-9U79_LGKMcsWLKBkY,22
|
|
3
|
-
bartz/debug.py,sha256=_HOjDieipAgliP6B6C0UMgz-mVgmeZ3zmtzVe-iMGtY,5289
|
|
4
|
-
bartz/grove.py,sha256=LHhnvNKLb-jxUf4YjP927Hf9txkXynhMZ2ejtMRWZl4,8353
|
|
5
|
-
bartz/interface.py,sha256=INyNuHzFySwXAsXofVZDpTsMv78AR_3VCvAHbZFh92c,15724
|
|
6
|
-
bartz/jaxext.py,sha256=FK5j1zfW1yR4-yPKcD7ZvKSkVQ5--jHjQpVCl4n4gXY,2844
|
|
7
|
-
bartz/mcmcloop.py,sha256=xTxC1AkNX8jCrMArblvlMjnjMh80q1M3a6ZGrDdfsFI,7423
|
|
8
|
-
bartz/mcmcstep.py,sha256=6zkpTqgIrapeVy9mhy6BlsIO0s26HwBRDfw_6dVMmZA,35207
|
|
9
|
-
bartz/prepcovars.py,sha256=3ddDOtNNop3Ba2Kgy_dZ6apFydtwaEXH3uXSmmKf9Fs,4421
|
|
10
|
-
bartz-0.1.0.dist-info/LICENSE,sha256=heuIJZQK9IexJYC-fYHoLUrgj8HG8yS3G072EvKh-94,1073
|
|
11
|
-
bartz-0.1.0.dist-info/METADATA,sha256=8YYlbCf7frDtT2of6tNlnBbuGqyO8YyYlED8OXSiBpA,933
|
|
12
|
-
bartz-0.1.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
|
13
|
-
bartz-0.1.0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|