bartz 0.0__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bartz/__init__.py +1 -1
- bartz/_version.py +1 -0
- bartz/debug.py +9 -27
- bartz/grove.py +71 -118
- bartz/interface.py +29 -32
- bartz/mcmcloop.py +17 -8
- bartz/mcmcstep.py +379 -427
- {bartz-0.0.dist-info → bartz-0.1.0.dist-info}/METADATA +8 -7
- bartz-0.1.0.dist-info/RECORD +13 -0
- bartz-0.0.dist-info/RECORD +0 -12
- {bartz-0.0.dist-info → bartz-0.1.0.dist-info}/LICENSE +0 -0
- {bartz-0.0.dist-info → bartz-0.1.0.dist-info}/WHEEL +0 -0
bartz/__init__.py
CHANGED
bartz/_version.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
__version__ = '0.1.0'
|
bartz/debug.py
CHANGED
|
@@ -7,22 +7,6 @@ from jax import lax
|
|
|
7
7
|
from . import grove
|
|
8
8
|
from . import mcmcstep
|
|
9
9
|
|
|
10
|
-
def trace_evaluate_trees(bart, X):
|
|
11
|
-
"""
|
|
12
|
-
Evaluate all trees, for all samples, at all x. Out axes:
|
|
13
|
-
0: mcmc sample
|
|
14
|
-
1: tree
|
|
15
|
-
2: X
|
|
16
|
-
"""
|
|
17
|
-
def loop(_, bart):
|
|
18
|
-
return None, evaluate_all_trees(X, bart['leaf_trees'], bart['var_trees'], bart['split_trees'])
|
|
19
|
-
_, y = lax.scan(loop, None, bart)
|
|
20
|
-
return y
|
|
21
|
-
|
|
22
|
-
@functools.partial(jax.vmap, in_axes=(None, 0, 0, 0)) # vectorize over forest
|
|
23
|
-
def evaluate_all_trees(X, leaf_trees, var_trees, split_trees):
|
|
24
|
-
return grove.evaluate_tree_vmap_x(X, leaf_trees, var_trees, split_trees, jnp.float32)
|
|
25
|
-
|
|
26
10
|
def print_tree(leaf_tree, var_tree, split_tree, print_all=False):
|
|
27
11
|
|
|
28
12
|
tee = '├──'
|
|
@@ -65,15 +49,11 @@ def print_tree(leaf_tree, var_tree, split_tree, print_all=False):
|
|
|
65
49
|
else:
|
|
66
50
|
link = ' '
|
|
67
51
|
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
number = str(index).rjust(ndigits)
|
|
72
|
-
number = f' {number} '
|
|
73
|
-
else:
|
|
74
|
-
number = ''
|
|
52
|
+
max_number = len(leaf_tree) - 1
|
|
53
|
+
ndigits = len(str(max_number))
|
|
54
|
+
number = str(index).rjust(ndigits)
|
|
75
55
|
|
|
76
|
-
print(f'{number}{indent}{first_indent}{link}{node_str}')
|
|
56
|
+
print(f' {number} {indent}{first_indent}{link}{node_str}')
|
|
77
57
|
|
|
78
58
|
indent += next_indent
|
|
79
59
|
unused = unused or is_leaf
|
|
@@ -101,8 +81,10 @@ def trace_depth_distr(split_trees_trace):
|
|
|
101
81
|
return jax.vmap(forest_depth_distr)(split_trees_trace)
|
|
102
82
|
|
|
103
83
|
def points_per_leaf_distr(var_tree, split_tree, X):
|
|
104
|
-
|
|
105
|
-
|
|
84
|
+
traverse_tree = jax.vmap(grove.traverse_tree, in_axes=(1, None, None))
|
|
85
|
+
indices = traverse_tree(X, var_tree, split_tree)
|
|
86
|
+
count_tree = jnp.zeros(2 * split_tree.size, dtype=grove.minimal_unsigned_dtype(indices.size))
|
|
87
|
+
count_tree = count_tree.at[indices].add(1)
|
|
106
88
|
is_leaf = grove.is_actual_leaf(split_tree, add_bottom_level=True).view(jnp.uint8)
|
|
107
89
|
return jnp.bincount(count_tree, is_leaf, length=X.shape[1] + 1)
|
|
108
90
|
|
|
@@ -129,7 +111,7 @@ def check_sizes(leaf_tree, var_tree, split_tree, max_split):
|
|
|
129
111
|
return leaf_tree.size == 2 * var_tree.size == 2 * split_tree.size
|
|
130
112
|
|
|
131
113
|
def check_unused_node(leaf_tree, var_tree, split_tree, max_split):
|
|
132
|
-
return (
|
|
114
|
+
return (var_tree[0] == 0) & (split_tree[0] == 0)
|
|
133
115
|
|
|
134
116
|
def check_leaf_values(leaf_tree, var_tree, split_tree, max_split):
|
|
135
117
|
return jnp.all(jnp.isfinite(leaf_tree))
|
bartz/grove.py
CHANGED
|
@@ -28,13 +28,15 @@ Functions to create and manipulate binary trees.
|
|
|
28
28
|
|
|
29
29
|
A tree is represented with arrays as a heap. The root node is at index 1. The child nodes of a node at index :math:`i` are at indices :math:`2i` (left child) and :math:`2i + 1` (right child). The array element at index 0 is unused.
|
|
30
30
|
|
|
31
|
-
A decision tree is represented by tree arrays: 'leaf', 'var', and 'split'.
|
|
31
|
+
A decision tree is represented by tree arrays: 'leaf', 'var', and 'split'.
|
|
32
32
|
|
|
33
|
-
|
|
33
|
+
The 'leaf' array contains the values in the leaves.
|
|
34
34
|
|
|
35
|
-
|
|
35
|
+
The 'var' array contains the axes along which the decision nodes operate.
|
|
36
|
+
|
|
37
|
+
The 'split' array contains the decision boundaries. The boundaries are open on the right, i.e., a point belongs to the left child iff x < split. Whether a node is a leaf is indicated by the corresponding 'split' element being 0.
|
|
36
38
|
|
|
37
|
-
|
|
39
|
+
Since the nodes at the bottom can only be leaves and not decision nodes, the 'var' and 'split' arrays have half the length of the 'leaf' array.
|
|
38
40
|
|
|
39
41
|
"""
|
|
40
42
|
|
|
@@ -42,6 +44,7 @@ import functools
|
|
|
42
44
|
import math
|
|
43
45
|
|
|
44
46
|
import jax
|
|
47
|
+
|
|
45
48
|
from jax import numpy as jnp
|
|
46
49
|
from jax import lax
|
|
47
50
|
|
|
@@ -63,24 +66,18 @@ def make_tree(depth, dtype):
|
|
|
63
66
|
-------
|
|
64
67
|
tree : array
|
|
65
68
|
An array of zeroes with shape (2 ** depth,).
|
|
66
|
-
|
|
67
|
-
Notes
|
|
68
|
-
-----
|
|
69
|
-
The tree is represented as a heap, with the root node at index 1, and the
|
|
70
|
-
children of the node at index i at indices 2 * i and 2 * i + 1. The element
|
|
71
|
-
at index 0 is unused.
|
|
72
69
|
"""
|
|
73
70
|
return jnp.zeros(2 ** depth, dtype)
|
|
74
71
|
|
|
75
72
|
def tree_depth(tree):
|
|
76
73
|
"""
|
|
77
|
-
Return the maximum depth of a
|
|
74
|
+
Return the maximum depth of a tree.
|
|
78
75
|
|
|
79
76
|
Parameters
|
|
80
77
|
----------
|
|
81
78
|
tree : array
|
|
82
|
-
A
|
|
83
|
-
|
|
79
|
+
A tree created by `make_tree`. If the array is ND, the tree structure is
|
|
80
|
+
assumed to be along the last axis.
|
|
84
81
|
|
|
85
82
|
Returns
|
|
86
83
|
-------
|
|
@@ -89,120 +86,97 @@ def tree_depth(tree):
|
|
|
89
86
|
"""
|
|
90
87
|
return int(round(math.log2(tree.shape[-1])))
|
|
91
88
|
|
|
92
|
-
def
|
|
89
|
+
def traverse_tree(x, var_tree, split_tree):
|
|
93
90
|
"""
|
|
94
|
-
|
|
91
|
+
Find the leaf into which a point falls.
|
|
95
92
|
|
|
96
93
|
Parameters
|
|
97
94
|
----------
|
|
98
|
-
|
|
95
|
+
x : array (p,)
|
|
99
96
|
The coordinates to evaluate the tree at.
|
|
100
|
-
|
|
101
|
-
The
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
The variable indices of the tree or forest. Each index is in [0, p) and
|
|
105
|
-
indicates which value of `X` to consider.
|
|
106
|
-
split_trees : array (n,) or (m, n)
|
|
107
|
-
The split values of the tree or forest. Leaf nodes are indicated by the
|
|
108
|
-
condition `split == 0`. If non-zero, the node has children, and its left
|
|
109
|
-
children is assigned points which satisfy `x < split`.
|
|
110
|
-
out_dtype : dtype
|
|
111
|
-
The dtype of the output.
|
|
97
|
+
var_tree : array (2 ** (d - 1),)
|
|
98
|
+
The decision axes of the tree.
|
|
99
|
+
split_tree : array (2 ** (d - 1),)
|
|
100
|
+
The decision boundaries of the tree.
|
|
112
101
|
|
|
113
102
|
Returns
|
|
114
103
|
-------
|
|
115
|
-
|
|
116
|
-
The
|
|
104
|
+
index : int
|
|
105
|
+
The index of the leaf.
|
|
117
106
|
"""
|
|
118
107
|
|
|
119
|
-
is_forest = leaf_trees.ndim == 2
|
|
120
|
-
if is_forest:
|
|
121
|
-
m, _ = leaf_trees.shape
|
|
122
|
-
forest_shape = m,
|
|
123
|
-
tree_index = jnp.arange(m, dtype=minimal_unsigned_dtype(m - 1)),
|
|
124
|
-
else:
|
|
125
|
-
forest_shape = ()
|
|
126
|
-
tree_index = ()
|
|
127
|
-
|
|
128
108
|
carry = (
|
|
129
|
-
jnp.zeros(
|
|
130
|
-
jnp.
|
|
131
|
-
jnp.ones(forest_shape, minimal_unsigned_dtype(leaf_trees.shape[-1] - 1))
|
|
109
|
+
jnp.zeros((), bool),
|
|
110
|
+
jnp.ones((), minimal_unsigned_dtype(2 * var_tree.size - 1)),
|
|
132
111
|
)
|
|
133
112
|
|
|
134
113
|
def loop(carry, _):
|
|
135
|
-
leaf_found,
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
if is_forest:
|
|
140
|
-
leaf_sum = jnp.sum(leaf_value, where=is_leaf) # TODO set dtype to large float
|
|
141
|
-
# alternative: dot(is_leaf, leaf_value):
|
|
142
|
-
# - maybe faster
|
|
143
|
-
# - maybe less accurate
|
|
144
|
-
# - fucked by nans
|
|
145
|
-
else:
|
|
146
|
-
leaf_sum = jnp.where(is_leaf, leaf_value, 0)
|
|
147
|
-
out += leaf_sum
|
|
148
|
-
leaf_found |= is_leaf
|
|
149
|
-
|
|
150
|
-
split = split_trees.at[tree_index + (node_index,)].get(mode='fill', fill_value=0)
|
|
151
|
-
var = var_trees.at[tree_index + (node_index,)].get(mode='fill', fill_value=0)
|
|
152
|
-
x = X[var]
|
|
114
|
+
leaf_found, index = carry
|
|
115
|
+
|
|
116
|
+
split = split_tree.at[index].get(mode='fill', fill_value=0)
|
|
117
|
+
var = var_tree.at[index].get(mode='fill', fill_value=0)
|
|
153
118
|
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
119
|
+
leaf_found |= split_tree.at[index].get(mode='fill', fill_value=0) == 0
|
|
120
|
+
child_index = (index << 1) + (x[var] >= split)
|
|
121
|
+
index = jnp.where(leaf_found, index, child_index)
|
|
157
122
|
|
|
158
|
-
|
|
159
|
-
return carry, _
|
|
123
|
+
return (leaf_found, index), None
|
|
160
124
|
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
125
|
+
# TODO
|
|
126
|
+
# - unroll (how much? 5?)
|
|
127
|
+
# - separate and special-case the last iteration
|
|
164
128
|
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
maximum value.
|
|
169
|
-
"""
|
|
170
|
-
if max_value < 2 ** 8:
|
|
171
|
-
return jnp.uint8
|
|
172
|
-
if max_value < 2 ** 16:
|
|
173
|
-
return jnp.uint16
|
|
174
|
-
if max_value < 2 ** 32:
|
|
175
|
-
return jnp.uint32
|
|
176
|
-
return jnp.uint64
|
|
129
|
+
depth = 1 + tree_depth(var_tree)
|
|
130
|
+
(_, index), _ = lax.scan(loop, carry, None, depth)
|
|
131
|
+
return index
|
|
177
132
|
|
|
178
|
-
|
|
179
|
-
def evaluate_tree_vmap_x(X, leaf_trees, var_trees, split_trees, out_dtype):
|
|
133
|
+
def evaluate_forest(X, leaf_trees, var_trees, split_trees, dtype):
|
|
180
134
|
"""
|
|
181
|
-
Evaluate a
|
|
135
|
+
Evaluate an ensemble of trees at an array of points.
|
|
182
136
|
|
|
183
137
|
Parameters
|
|
184
138
|
----------
|
|
185
139
|
X : array (p, n)
|
|
186
|
-
The
|
|
187
|
-
leaf_trees :
|
|
140
|
+
The coordinates to evaluate the trees at.
|
|
141
|
+
leaf_trees : array (m, 2 ** d)
|
|
188
142
|
The leaf values of the tree or forest. If the input is a forest, the
|
|
189
143
|
first axis is the tree index, and the values are summed.
|
|
190
|
-
var_trees : array (
|
|
191
|
-
The
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
condition `split == 0`. If non-zero, the node has children, and its left
|
|
196
|
-
children is assigned points which satisfy `x < split`.
|
|
197
|
-
out_dtype : dtype
|
|
144
|
+
var_trees : array (m, 2 ** (d - 1))
|
|
145
|
+
The decision axes of the trees.
|
|
146
|
+
split_trees : array (m, 2 ** (d - 1))
|
|
147
|
+
The decision boundaries of the trees.
|
|
148
|
+
dtype : dtype
|
|
198
149
|
The dtype of the output.
|
|
199
150
|
|
|
200
151
|
Returns
|
|
201
152
|
-------
|
|
202
|
-
out : (n,)
|
|
203
|
-
The
|
|
153
|
+
out : array (n,)
|
|
154
|
+
The sum of the values of the trees at the points in `X`.
|
|
204
155
|
"""
|
|
205
|
-
|
|
156
|
+
indices = _traverse_forest(X, var_trees, split_trees)
|
|
157
|
+
ntree, _ = leaf_trees.shape
|
|
158
|
+
tree_index = jnp.arange(ntree, dtype=minimal_unsigned_dtype(ntree - 1))[:, None]
|
|
159
|
+
leaves = leaf_trees[tree_index, indices]
|
|
160
|
+
return jnp.sum(leaves, axis=0, dtype=dtype)
|
|
161
|
+
# this sum suggests swapping the vmaps, but I think keeping the current order is better for X copying
|
|
162
|
+
|
|
163
|
+
@functools.partial(jax.vmap, in_axes=(None, 0, 0))
|
|
164
|
+
@functools.partial(jax.vmap, in_axes=(1, None, None))
|
|
165
|
+
def _traverse_forest(X, var_trees, split_trees):
|
|
166
|
+
return traverse_tree(X, var_trees, split_trees)
|
|
167
|
+
|
|
168
|
+
def minimal_unsigned_dtype(max_value):
|
|
169
|
+
"""
|
|
170
|
+
Return the smallest unsigned integer dtype that can represent a given
|
|
171
|
+
maximum value.
|
|
172
|
+
"""
|
|
173
|
+
if max_value < 2 ** 8:
|
|
174
|
+
return jnp.uint8
|
|
175
|
+
if max_value < 2 ** 16:
|
|
176
|
+
return jnp.uint16
|
|
177
|
+
if max_value < 2 ** 32:
|
|
178
|
+
return jnp.uint32
|
|
179
|
+
return jnp.uint64
|
|
206
180
|
|
|
207
181
|
def is_actual_leaf(split_tree, *, add_bottom_level=False):
|
|
208
182
|
"""
|
|
@@ -239,7 +213,7 @@ def is_leaves_parent(split_tree):
|
|
|
239
213
|
Parameters
|
|
240
214
|
----------
|
|
241
215
|
split_tree : int array (2 ** (d - 1),)
|
|
242
|
-
The
|
|
216
|
+
The decision boundaries of the tree.
|
|
243
217
|
|
|
244
218
|
Returns
|
|
245
219
|
-------
|
|
@@ -279,24 +253,3 @@ def tree_depths(tree_length):
|
|
|
279
253
|
depths.append(depth - 1)
|
|
280
254
|
depths[0] = 0
|
|
281
255
|
return jnp.array(depths, minimal_unsigned_dtype(max(depths)))
|
|
282
|
-
|
|
283
|
-
def index_depth(index, tree_length):
|
|
284
|
-
"""
|
|
285
|
-
Return the depth of a node in a binary tree.
|
|
286
|
-
|
|
287
|
-
Parameters
|
|
288
|
-
----------
|
|
289
|
-
index : int
|
|
290
|
-
The index of the node.
|
|
291
|
-
tree_length : int
|
|
292
|
-
The length of the tree array, i.e., 2 ** d.
|
|
293
|
-
|
|
294
|
-
Returns
|
|
295
|
-
-------
|
|
296
|
-
depth : int
|
|
297
|
-
The depth of the node. The root node (index 1) has depth 0. The depth is
|
|
298
|
-
the position of the most significant non-zero bit in the index. If
|
|
299
|
-
``index == 0``, return -1.
|
|
300
|
-
"""
|
|
301
|
-
depths = tree_depths(tree_length)
|
|
302
|
-
return depths[index]
|
bartz/interface.py
CHANGED
|
@@ -38,7 +38,7 @@ class BART:
|
|
|
38
38
|
Nonparametric regression with Bayesian Additive Regression Trees (BART).
|
|
39
39
|
|
|
40
40
|
Regress `y_train` on `x_train` with a latent mean function represented as
|
|
41
|
-
a sum of decision trees. The inference is carried out by
|
|
41
|
+
a sum of decision trees. The inference is carried out by sampling the
|
|
42
42
|
posterior distribution of the tree ensemble with an MCMC.
|
|
43
43
|
|
|
44
44
|
Parameters
|
|
@@ -86,7 +86,7 @@ class BART:
|
|
|
86
86
|
predictor is binned such that its distribution in `x_train` is
|
|
87
87
|
approximately uniform across bins. The number of bins is at most the
|
|
88
88
|
number of unique values appearing in `x_train`, or ``numcut + 1``.
|
|
89
|
-
Before running the algorithm, the predictors are compressed to
|
|
89
|
+
Before running the algorithm, the predictors are compressed to the
|
|
90
90
|
smallest integer type that fits the bin indices, so `numcut` is best set
|
|
91
91
|
to the maximum value of an unsigned integer type.
|
|
92
92
|
ndpost : int, default 1000
|
|
@@ -102,14 +102,6 @@ class BART:
|
|
|
102
102
|
|
|
103
103
|
Attributes
|
|
104
104
|
----------
|
|
105
|
-
offset : float
|
|
106
|
-
The prior mean of the latent mean function.
|
|
107
|
-
scale : float
|
|
108
|
-
The prior standard deviation of the latent mean function.
|
|
109
|
-
lamda : float
|
|
110
|
-
The prior harmonic mean of the error variance.
|
|
111
|
-
ntree : int
|
|
112
|
-
The number of trees.
|
|
113
105
|
yhat_train : array (ndpost, n)
|
|
114
106
|
The conditional posterior mean at `x_train` for each MCMC iteration.
|
|
115
107
|
yhat_train_mean : array (n,)
|
|
@@ -122,6 +114,18 @@ class BART:
|
|
|
122
114
|
The standard deviation of the error.
|
|
123
115
|
first_sigma : array (nskip,)
|
|
124
116
|
The standard deviation of the error in the burn-in phase.
|
|
117
|
+
offset : float
|
|
118
|
+
The prior mean of the latent mean function.
|
|
119
|
+
scale : float
|
|
120
|
+
The prior standard deviation of the latent mean function.
|
|
121
|
+
lamda : float
|
|
122
|
+
The prior harmonic mean of the error variance.
|
|
123
|
+
sigest : float or None
|
|
124
|
+
The estimated standard deviation of the error used to set `lamda`.
|
|
125
|
+
ntree : int
|
|
126
|
+
The number of trees.
|
|
127
|
+
maxdepth : int
|
|
128
|
+
The maximum depth of the trees.
|
|
125
129
|
|
|
126
130
|
Methods
|
|
127
131
|
-------
|
|
@@ -166,17 +170,17 @@ class BART:
|
|
|
166
170
|
y_train, y_train_fmt = self._process_response_input(y_train)
|
|
167
171
|
self._check_same_length(x_train, y_train)
|
|
168
172
|
|
|
169
|
-
lamda = self._process_noise_variance_settings(x_train, y_train, sigest, sigdf, sigquant, lamda)
|
|
170
173
|
offset = self._process_offset_settings(y_train, offset)
|
|
171
174
|
scale = self._process_scale_settings(y_train, k)
|
|
175
|
+
lamda, sigest = self._process_noise_variance_settings(x_train, y_train, sigest, sigdf, sigquant, lamda, offset)
|
|
172
176
|
|
|
173
177
|
splits, max_split = self._determine_splits(x_train, numcut)
|
|
174
178
|
x_train = self._bin_predictors(x_train, splits)
|
|
175
179
|
|
|
176
180
|
y_train = self._transform_input(y_train, offset, scale)
|
|
177
|
-
|
|
181
|
+
lamda_scaled = lamda / (scale * scale)
|
|
178
182
|
|
|
179
|
-
mcmc_state = self._setup_mcmc(x_train, y_train, max_split,
|
|
183
|
+
mcmc_state = self._setup_mcmc(x_train, y_train, max_split, lamda_scaled, sigdf, power, base, maxdepth, ntree)
|
|
180
184
|
final_state, burnin_trace, main_trace = self._run_mcmc(mcmc_state, ndpost, nskip, keepevery, printevery, seed)
|
|
181
185
|
|
|
182
186
|
sigma = self._extract_sigma(main_trace, scale)
|
|
@@ -184,8 +188,10 @@ class BART:
|
|
|
184
188
|
|
|
185
189
|
self.offset = offset
|
|
186
190
|
self.scale = scale
|
|
187
|
-
self.lamda = lamda
|
|
191
|
+
self.lamda = lamda
|
|
192
|
+
self.sigest = sigest
|
|
188
193
|
self.ntree = ntree
|
|
194
|
+
self.maxdepth = maxdepth
|
|
189
195
|
self.sigma = sigma
|
|
190
196
|
self.first_sigma = first_sigma
|
|
191
197
|
|
|
@@ -261,25 +267,25 @@ class BART:
|
|
|
261
267
|
assert get_length(x1) == get_length(x2)
|
|
262
268
|
|
|
263
269
|
@staticmethod
|
|
264
|
-
def _process_noise_variance_settings(x_train, y_train, sigest, sigdf, sigquant, lamda):
|
|
270
|
+
def _process_noise_variance_settings(x_train, y_train, sigest, sigdf, sigquant, lamda, offset):
|
|
265
271
|
if lamda is not None:
|
|
266
|
-
return lamda
|
|
272
|
+
return lamda, None
|
|
267
273
|
else:
|
|
268
274
|
if sigest is not None:
|
|
269
275
|
sigest2 = sigest * sigest
|
|
270
276
|
elif y_train.size < 2:
|
|
271
277
|
sigest2 = 1
|
|
272
278
|
elif y_train.size <= x_train.shape[0]:
|
|
273
|
-
sigest2 = jnp.var(y_train)
|
|
279
|
+
sigest2 = jnp.var(y_train - offset)
|
|
274
280
|
else:
|
|
275
|
-
_, chisq, rank, _ = jnp.linalg.lstsq(x_train.T, y_train)
|
|
281
|
+
_, chisq, rank, _ = jnp.linalg.lstsq(x_train.T, y_train - offset)
|
|
276
282
|
chisq = chisq.squeeze(0)
|
|
277
283
|
dof = len(y_train) - rank
|
|
278
284
|
sigest2 = chisq / dof
|
|
279
285
|
alpha = sigdf / 2
|
|
280
286
|
invchi2 = jaxext.scipy.stats.invgamma.ppf(sigquant, alpha) / 2
|
|
281
287
|
invchi2rid = invchi2 * sigdf
|
|
282
|
-
return sigest2 / invchi2rid
|
|
288
|
+
return sigest2 / invchi2rid, jnp.sqrt(sigest2)
|
|
283
289
|
|
|
284
290
|
@staticmethod
|
|
285
291
|
def _process_offset_settings(y_train, offset):
|
|
@@ -315,7 +321,7 @@ class BART:
|
|
|
315
321
|
p_nonterminal = base / (1 + depth).astype(float) ** power
|
|
316
322
|
sigma2_alpha = sigdf / 2
|
|
317
323
|
sigma2_beta = lamda * sigma2_alpha
|
|
318
|
-
return mcmcstep.
|
|
324
|
+
return mcmcstep.init(
|
|
319
325
|
X=x_train,
|
|
320
326
|
y=y_train,
|
|
321
327
|
max_split=max_split,
|
|
@@ -348,13 +354,6 @@ class BART:
|
|
|
348
354
|
return scale * jnp.sqrt(trace['sigma2'])
|
|
349
355
|
|
|
350
356
|
|
|
351
|
-
def _predict_debug(self, x_test):
|
|
352
|
-
from . import debug
|
|
353
|
-
x_test, x_test_fmt = self._process_predictor_input(x_test)
|
|
354
|
-
self._check_compatible_formats(x_test_fmt, self._x_train_fmt)
|
|
355
|
-
x_test = self._bin_predictors(x_test, self._splits)
|
|
356
|
-
return debug.trace_evaluate_trees(self._main_trace, x_test)
|
|
357
|
-
|
|
358
357
|
def _show_tree(self, i_sample, i_tree, print_all=False):
|
|
359
358
|
from . import debug
|
|
360
359
|
trace = self._main_trace
|
|
@@ -379,7 +378,7 @@ class BART:
|
|
|
379
378
|
def _compare_resid(self):
|
|
380
379
|
bart = self._mcmc_state
|
|
381
380
|
resid1 = bart['resid']
|
|
382
|
-
yhat = grove.
|
|
381
|
+
yhat = grove.evaluate_forest(bart['X'], bart['leaf_trees'], bart['var_trees'], bart['split_trees'], jnp.float32)
|
|
383
382
|
resid2 = bart['y'] - yhat
|
|
384
383
|
return resid1, resid2
|
|
385
384
|
|
|
@@ -421,7 +420,5 @@ class BART:
|
|
|
421
420
|
|
|
422
421
|
def _tree_goes_bad(self):
|
|
423
422
|
bad = self._check_trees().astype(bool)
|
|
424
|
-
bad_before = bad[:-1]
|
|
425
|
-
|
|
426
|
-
goes_bad = bad_after & ~bad_before
|
|
427
|
-
return jnp.pad(goes_bad, [(1, 0), (0, 0)])
|
|
423
|
+
bad_before = jnp.pad(bad[:-1], [(1, 0), (0, 0)])
|
|
424
|
+
return bad & ~bad_before
|
bartz/mcmcloop.py
CHANGED
|
@@ -100,15 +100,21 @@ def run_mcmc(bart, n_burn, n_save, n_skip, callback, key):
|
|
|
100
100
|
def inner_loop(carry, _, tracelist, burnin):
|
|
101
101
|
bart, i_total, i_skip, key = carry
|
|
102
102
|
key, subkey = random.split(key)
|
|
103
|
-
bart = mcmcstep.
|
|
103
|
+
bart = mcmcstep.step(bart, subkey)
|
|
104
104
|
callback(bart=bart, burnin=burnin, i_total=i_total, i_skip=i_skip, **callback_kw)
|
|
105
105
|
output = {key: bart[key] for key in tracelist}
|
|
106
106
|
return (bart, i_total + 1, i_skip + 1, key), output
|
|
107
107
|
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
108
|
+
if n_burn > 0:
|
|
109
|
+
carry = bart, 0, 0, key
|
|
110
|
+
burnin_loop = functools.partial(inner_loop, tracelist=tracelist_burnin, burnin=True)
|
|
111
|
+
(bart, i_total, _, key), burnin_trace = lax.scan(burnin_loop, carry, None, n_burn)
|
|
112
|
+
else:
|
|
113
|
+
i_total = 0
|
|
114
|
+
burnin_trace = {
|
|
115
|
+
key: jnp.empty((0,) + bart[key].shape, bart[key].dtype)
|
|
116
|
+
for key in tracelist_burnin
|
|
117
|
+
}
|
|
112
118
|
|
|
113
119
|
def outer_loop(carry, _):
|
|
114
120
|
bart, i_total, key = carry
|
|
@@ -148,14 +154,17 @@ def make_simple_print_callback(printevery):
|
|
|
148
154
|
prune_prop = bart['prune_prop_count'] / prop_total
|
|
149
155
|
grow_acc = bart['grow_acc_count'] / bart['grow_prop_count']
|
|
150
156
|
prune_acc = bart['prune_acc_count'] / bart['prune_prop_count']
|
|
151
|
-
n_total = n_burn + n_save
|
|
157
|
+
n_total = n_burn + n_save * n_skip
|
|
152
158
|
debug.callback(simple_print_callback_impl, burnin, i_total, n_total, grow_prop, grow_acc, prune_prop, prune_acc, printevery)
|
|
153
159
|
return callback
|
|
154
160
|
|
|
155
161
|
def simple_print_callback_impl(burnin, i_total, n_total, grow_prop, grow_acc, prune_prop, prune_acc, printevery):
|
|
156
162
|
if (i_total + 1) % printevery == 0:
|
|
157
163
|
burnin_flag = ' (burnin)' if burnin else ''
|
|
158
|
-
|
|
164
|
+
total_str = str(n_total)
|
|
165
|
+
ndigits = len(total_str)
|
|
166
|
+
i_str = str(i_total + 1).rjust(ndigits)
|
|
167
|
+
print(f'Iteration {i_str}/{total_str} '
|
|
159
168
|
f'P_grow={grow_prop:.2f} P_prune={prune_prop:.2f} '
|
|
160
169
|
f'A_grow={grow_acc:.2f} A_prune={prune_acc:.2f}{burnin_flag}')
|
|
161
170
|
|
|
@@ -177,6 +186,6 @@ def evaluate_trace(trace, X):
|
|
|
177
186
|
The predictions for each iteration of the MCMC.
|
|
178
187
|
"""
|
|
179
188
|
def loop(_, state):
|
|
180
|
-
return None, grove.
|
|
189
|
+
return None, grove.evaluate_forest(X, state['leaf_trees'], state['var_trees'], state['split_trees'], jnp.float32)
|
|
181
190
|
_, y = lax.scan(loop, None, trace)
|
|
182
191
|
return y
|