bartz 0.0__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bartz/__init__.py CHANGED
@@ -28,7 +28,7 @@ A jax implementation of BART
28
28
  See the manual at https://gattocrucco.github.io/bartz/docs
29
29
  """
30
30
 
31
- __version__ = '0.0'
31
+ from ._version import __version__
32
32
 
33
33
  from .interface import BART
34
34
 
bartz/_version.py ADDED
@@ -0,0 +1 @@
1
+ __version__ = '0.1.0'
bartz/debug.py CHANGED
@@ -7,22 +7,6 @@ from jax import lax
7
7
  from . import grove
8
8
  from . import mcmcstep
9
9
 
10
- def trace_evaluate_trees(bart, X):
11
- """
12
- Evaluate all trees, for all samples, at all x. Out axes:
13
- 0: mcmc sample
14
- 1: tree
15
- 2: X
16
- """
17
- def loop(_, bart):
18
- return None, evaluate_all_trees(X, bart['leaf_trees'], bart['var_trees'], bart['split_trees'])
19
- _, y = lax.scan(loop, None, bart)
20
- return y
21
-
22
- @functools.partial(jax.vmap, in_axes=(None, 0, 0, 0)) # vectorize over forest
23
- def evaluate_all_trees(X, leaf_trees, var_trees, split_trees):
24
- return grove.evaluate_tree_vmap_x(X, leaf_trees, var_trees, split_trees, jnp.float32)
25
-
26
10
  def print_tree(leaf_tree, var_tree, split_tree, print_all=False):
27
11
 
28
12
  tee = '├──'
@@ -65,15 +49,11 @@ def print_tree(leaf_tree, var_tree, split_tree, print_all=False):
65
49
  else:
66
50
  link = ' '
67
51
 
68
- if print_all:
69
- max_number = len(leaf_tree) - 1
70
- ndigits = len(str(max_number))
71
- number = str(index).rjust(ndigits)
72
- number = f' {number} '
73
- else:
74
- number = ''
52
+ max_number = len(leaf_tree) - 1
53
+ ndigits = len(str(max_number))
54
+ number = str(index).rjust(ndigits)
75
55
 
76
- print(f'{number}{indent}{first_indent}{link}{node_str}')
56
+ print(f' {number} {indent}{first_indent}{link}{node_str}')
77
57
 
78
58
  indent += next_indent
79
59
  unused = unused or is_leaf
@@ -101,8 +81,10 @@ def trace_depth_distr(split_trees_trace):
101
81
  return jax.vmap(forest_depth_distr)(split_trees_trace)
102
82
 
103
83
  def points_per_leaf_distr(var_tree, split_tree, X):
104
- dummy = jnp.ones(X.shape[1], jnp.uint8)
105
- _, count_tree = mcmcstep.agg_values(X, var_tree, split_tree, dummy, dummy.dtype)
84
+ traverse_tree = jax.vmap(grove.traverse_tree, in_axes=(1, None, None))
85
+ indices = traverse_tree(X, var_tree, split_tree)
86
+ count_tree = jnp.zeros(2 * split_tree.size, dtype=grove.minimal_unsigned_dtype(indices.size))
87
+ count_tree = count_tree.at[indices].add(1)
106
88
  is_leaf = grove.is_actual_leaf(split_tree, add_bottom_level=True).view(jnp.uint8)
107
89
  return jnp.bincount(count_tree, is_leaf, length=X.shape[1] + 1)
108
90
 
@@ -129,7 +111,7 @@ def check_sizes(leaf_tree, var_tree, split_tree, max_split):
129
111
  return leaf_tree.size == 2 * var_tree.size == 2 * split_tree.size
130
112
 
131
113
  def check_unused_node(leaf_tree, var_tree, split_tree, max_split):
132
- return (leaf_tree[0] == 0) & (var_tree[0] == 0) & (split_tree[0] == 0)
114
+ return (var_tree[0] == 0) & (split_tree[0] == 0)
133
115
 
134
116
  def check_leaf_values(leaf_tree, var_tree, split_tree, max_split):
135
117
  return jnp.all(jnp.isfinite(leaf_tree))
bartz/grove.py CHANGED
@@ -28,13 +28,15 @@ Functions to create and manipulate binary trees.
28
28
 
29
29
  A tree is represented with arrays as a heap. The root node is at index 1. The children nodes of a node at index :math:`i` are at indices :math:`2i` (left child) and :math:`2i + 1` (right child). The array element at index 0 is unused.
30
30
 
31
- A decision tree is represented by tree arrays: 'leaf', 'var', and 'split'. The 'leaf' array contains the values in the leaves. The 'var' array contains the axes along which the decision nodes operate. The 'split' array contains the decision boundaries.
31
+ A decision tree is represented by three arrays: 'leaf', 'var', and 'split'.
32
32
 
33
- Whether a node is a leaf is indicated by the corresponding 'split' element being 0.
33
+ The 'leaf' array contains the values in the leaves.
34
34
 
35
- Since the nodes at the bottom can only be leaves and not decision nodes, the 'var' and 'split' arrays have half the length of the 'leaf' array.
35
+ The 'var' array contains the axes along which the decision nodes operate.
36
+
37
+ The 'split' array contains the decision boundaries. The boundaries are open on the right, i.e., a point belongs to the left child iff x < split. Whether a node is a leaf is indicated by the corresponding 'split' element being 0.
36
38
 
37
- The unused array element at index 0 is always fixed to 0 by convention.
39
+ Since the nodes at the bottom can only be leaves and not decision nodes, the 'var' and 'split' arrays have half the length of the 'leaf' array.
38
40
 
39
41
  """
40
42
 
@@ -42,6 +44,7 @@ import functools
42
44
  import math
43
45
 
44
46
  import jax
47
+
45
48
  from jax import numpy as jnp
46
49
  from jax import lax
47
50
 
@@ -63,24 +66,18 @@ def make_tree(depth, dtype):
63
66
  -------
64
67
  tree : array
65
68
  An array of zeroes with shape (2 ** depth,).
66
-
67
- Notes
68
- -----
69
- The tree is represented as a heap, with the root node at index 1, and the
70
- children of the node at index i at indices 2 * i and 2 * i + 1. The element
71
- at index 0 is unused.
72
69
  """
73
70
  return jnp.zeros(2 ** depth, dtype)
74
71
 
75
72
  def tree_depth(tree):
76
73
  """
77
- Return the maximum depth of a binary tree created by `make_tree`.
74
+ Return the maximum depth of a tree.
78
75
 
79
76
  Parameters
80
77
  ----------
81
78
  tree : array
82
- A binary tree created by `make_tree`. If the array is ND, the tree
83
- structure is assumed to be along the last axis.
79
+ A tree created by `make_tree`. If the array is ND, the tree structure is
80
+ assumed to be along the last axis.
84
81
 
85
82
  Returns
86
83
  -------
@@ -89,120 +86,97 @@ def tree_depth(tree):
89
86
  """
90
87
  return int(round(math.log2(tree.shape[-1])))
91
88
 
92
- def evaluate_tree(X, leaf_trees, var_trees, split_trees, out_dtype):
89
+ def traverse_tree(x, var_tree, split_tree):
93
90
  """
94
- Evaluate a decision tree or forest.
91
+ Find the leaf where a point falls into.
95
92
 
96
93
  Parameters
97
94
  ----------
98
- X : array (p,)
95
+ x : array (p,)
99
96
  The coordinates to evaluate the tree at.
100
- leaf_trees : array (n,) or (m, n)
101
- The leaf values of the tree or forest. If the input is a forest, the
102
- first axis is the tree index, and the values are summed.
103
- var_trees : array (n,) or (m, n)
104
- The variable indices of the tree or forest. Each index is in [0, p) and
105
- indicates which value of `X` to consider.
106
- split_trees : array (n,) or (m, n)
107
- The split values of the tree or forest. Leaf nodes are indicated by the
108
- condition `split == 0`. If non-zero, the node has children, and its left
109
- children is assigned points which satisfy `x < split`.
110
- out_dtype : dtype
111
- The dtype of the output.
97
+ var_tree : array (2 ** (d - 1),)
98
+ The decision axes of the tree.
99
+ split_tree : array (2 ** (d - 1),)
100
+ The decision boundaries of the tree.
112
101
 
113
102
  Returns
114
103
  -------
115
- out : scalar
116
- The value of the tree or forest at the given point.
104
+ index : int
105
+ The index of the leaf.
117
106
  """
118
107
 
119
- is_forest = leaf_trees.ndim == 2
120
- if is_forest:
121
- m, _ = leaf_trees.shape
122
- forest_shape = m,
123
- tree_index = jnp.arange(m, dtype=minimal_unsigned_dtype(m - 1)),
124
- else:
125
- forest_shape = ()
126
- tree_index = ()
127
-
128
108
  carry = (
129
- jnp.zeros(forest_shape, bool),
130
- jnp.zeros((), out_dtype),
131
- jnp.ones(forest_shape, minimal_unsigned_dtype(leaf_trees.shape[-1] - 1))
109
+ jnp.zeros((), bool),
110
+ jnp.ones((), minimal_unsigned_dtype(2 * var_tree.size - 1)),
132
111
  )
133
112
 
134
113
  def loop(carry, _):
135
- leaf_found, out, node_index = carry
136
-
137
- is_leaf = split_trees.at[tree_index + (node_index,)].get(mode='fill', fill_value=0) == 0
138
- leaf_value = leaf_trees[tree_index + (node_index,)]
139
- if is_forest:
140
- leaf_sum = jnp.sum(leaf_value, where=is_leaf) # TODO set dtype to large float
141
- # alternative: dot(is_leaf, leaf_value):
142
- # - maybe faster
143
- # - maybe less accurate
144
- # - fucked by nans
145
- else:
146
- leaf_sum = jnp.where(is_leaf, leaf_value, 0)
147
- out += leaf_sum
148
- leaf_found |= is_leaf
149
-
150
- split = split_trees.at[tree_index + (node_index,)].get(mode='fill', fill_value=0)
151
- var = var_trees.at[tree_index + (node_index,)].get(mode='fill', fill_value=0)
152
- x = X[var]
114
+ leaf_found, index = carry
115
+
116
+ split = split_tree.at[index].get(mode='fill', fill_value=0)
117
+ var = var_tree.at[index].get(mode='fill', fill_value=0)
153
118
 
154
- node_index <<= 1
155
- node_index += x >= split
156
- node_index = jnp.where(leaf_found, 0, node_index)
119
+ leaf_found |= split_tree.at[index].get(mode='fill', fill_value=0) == 0
120
+ child_index = (index << 1) + (x[var] >= split)
121
+ index = jnp.where(leaf_found, index, child_index)
157
122
 
158
- carry = leaf_found, out, node_index
159
- return carry, _
123
+ return (leaf_found, index), None
160
124
 
161
- depth = tree_depth(leaf_trees)
162
- (_, out, _), _ = lax.scan(loop, carry, None, depth)
163
- return out
125
+ # TODO
126
+ # - unroll (how much? 5?)
127
+ # - separate and special-case the last iteration
164
128
 
165
- def minimal_unsigned_dtype(max_value):
166
- """
167
- Return the smallest unsigned integer dtype that can represent a given
168
- maximum value.
169
- """
170
- if max_value < 2 ** 8:
171
- return jnp.uint8
172
- if max_value < 2 ** 16:
173
- return jnp.uint16
174
- if max_value < 2 ** 32:
175
- return jnp.uint32
176
- return jnp.uint64
129
+ depth = 1 + tree_depth(var_tree)
130
+ (_, index), _ = lax.scan(loop, carry, None, depth)
131
+ return index
177
132
 
178
- @functools.partial(jaxext.vmap_nodoc, in_axes=(1, None, None, None, None), out_axes=0)
179
- def evaluate_tree_vmap_x(X, leaf_trees, var_trees, split_trees, out_dtype):
133
+ def evaluate_forest(X, leaf_trees, var_trees, split_trees, dtype):
180
134
  """
181
- Evaluate a decision tree or forest over multiple points.
135
- Evaluate an ensemble of trees at an array of points.
182
136
 
183
137
  Parameters
184
138
  ----------
185
139
  X : array (p, n)
186
- The points to evaluate the tree at.
187
- leaf_trees : array (n,) or (m, n)
140
+ The coordinates to evaluate the trees at.
141
+ leaf_trees : (m, 2 ** d)
188
142
  The leaf values of the tree or forest. If the input is a forest, the
189
143
  first axis is the tree index, and the values are summed.
190
- var_trees : array (n,) or (m, n)
191
- The variable indices of the tree or forest. Each index is in [0, p) and
192
- indicates which value of `X` to consider.
193
- split_trees : array (n,) or (m, n)
194
- The split values of the tree or forest. Leaf nodes are indicated by the
195
- condition `split == 0`. If non-zero, the node has children, and its left
196
- children is assigned points which satisfy `x < split`.
197
- out_dtype : dtype
144
+ var_trees : array (m, 2 ** (d - 1))
145
+ The decision axes of the trees.
146
+ split_trees : array (m, 2 ** (d - 1))
147
+ The decision boundaries of the trees.
148
+ dtype : dtype
198
149
  The dtype of the output.
199
150
 
200
151
  Returns
201
152
  -------
202
- out : (n,)
203
- The value of the tree or forest at each point.
153
+ out : array (n,)
154
+ The sum of the values of the trees at the points in `X`.
204
155
  """
205
- return evaluate_tree(X, leaf_trees, var_trees, split_trees, out_dtype)
156
+ indices = _traverse_forest(X, var_trees, split_trees)
157
+ ntree, _ = leaf_trees.shape
158
+ tree_index = jnp.arange(ntree, dtype=minimal_unsigned_dtype(ntree - 1))[:, None]
159
+ leaves = leaf_trees[tree_index, indices]
160
+ return jnp.sum(leaves, axis=0, dtype=dtype)
161
+ # this sum suggests to swap the vmaps, but I think it's better for X copying to keep it that way
162
+
163
+ @functools.partial(jax.vmap, in_axes=(None, 0, 0))
164
+ @functools.partial(jax.vmap, in_axes=(1, None, None))
165
+ def _traverse_forest(X, var_trees, split_trees):
166
+ return traverse_tree(X, var_trees, split_trees)
167
+
168
+ def minimal_unsigned_dtype(max_value):
169
+ """
170
+ Return the smallest unsigned integer dtype that can represent a given
171
+ maximum value.
172
+ """
173
+ if max_value < 2 ** 8:
174
+ return jnp.uint8
175
+ if max_value < 2 ** 16:
176
+ return jnp.uint16
177
+ if max_value < 2 ** 32:
178
+ return jnp.uint32
179
+ return jnp.uint64
206
180
 
207
181
  def is_actual_leaf(split_tree, *, add_bottom_level=False):
208
182
  """
@@ -239,7 +213,7 @@ def is_leaves_parent(split_tree):
239
213
  Parameters
240
214
  ----------
241
215
  split_tree : int array (2 ** (d - 1),)
242
- The splitting points of the tree.
216
+ The decision boundaries of the tree.
243
217
 
244
218
  Returns
245
219
  -------
@@ -279,24 +253,3 @@ def tree_depths(tree_length):
279
253
  depths.append(depth - 1)
280
254
  depths[0] = 0
281
255
  return jnp.array(depths, minimal_unsigned_dtype(max(depths)))
282
-
283
- def index_depth(index, tree_length):
284
- """
285
- Return the depth of a node in a binary tree.
286
-
287
- Parameters
288
- ----------
289
- index : int
290
- The index of the node.
291
- tree_length : int
292
- The length of the tree array, i.e., 2 ** d.
293
-
294
- Returns
295
- -------
296
- depth : int
297
- The depth of the node. The root node (index 1) has depth 0. The depth is
298
- the position of the most significant non-zero bit in the index. If
299
- ``index == 0``, return -1.
300
- """
301
- depths = tree_depths(tree_length)
302
- return depths[index]
bartz/interface.py CHANGED
@@ -38,7 +38,7 @@ class BART:
38
38
  Nonparametric regression with Bayesian Additive Regression Trees (BART).
39
39
 
40
40
  Regress `y_train` on `x_train` with a latent mean function represented as
41
- a sum of decision trees. The inference is carried out by estimating the
41
+ a sum of decision trees. The inference is carried out by sampling the
42
42
  posterior distribution of the tree ensemble with an MCMC.
43
43
 
44
44
  Parameters
@@ -86,7 +86,7 @@ class BART:
86
86
  predictor is binned such that its distribution in `x_train` is
87
87
  approximately uniform across bins. The number of bins is at most the
88
88
  number of unique values appearing in `x_train`, or ``numcut + 1``.
89
- Before running the algorithm, the predictors are compressed to th
89
+ Before running the algorithm, the predictors are compressed to the
90
90
  smallest integer type that fits the bin indices, so `numcut` is best set
91
91
  to the maximum value of an unsigned integer type.
92
92
  ndpost : int, default 1000
@@ -102,14 +102,6 @@ class BART:
102
102
 
103
103
  Attributes
104
104
  ----------
105
- offset : float
106
- The prior mean of the latent mean function.
107
- scale : float
108
- The prior standard deviation of the latent mean function.
109
- lamda : float
110
- The prior harmonic mean of the error variance.
111
- ntree : int
112
- The number of trees.
113
105
  yhat_train : array (ndpost, n)
114
106
  The conditional posterior mean at `x_train` for each MCMC iteration.
115
107
  yhat_train_mean : array (n,)
@@ -122,6 +114,18 @@ class BART:
122
114
  The standard deviation of the error.
123
115
  first_sigma : array (nskip,)
124
116
  The standard deviation of the error in the burn-in phase.
117
+ offset : float
118
+ The prior mean of the latent mean function.
119
+ scale : float
120
+ The prior standard deviation of the latent mean function.
121
+ lamda : float
122
+ The prior harmonic mean of the error variance.
123
+ sigest : float or None
124
+ The estimated standard deviation of the error used to set `lamda`.
125
+ ntree : int
126
+ The number of trees.
127
+ maxdepth : int
128
+ The maximum depth of the trees.
125
129
 
126
130
  Methods
127
131
  -------
@@ -166,17 +170,17 @@ class BART:
166
170
  y_train, y_train_fmt = self._process_response_input(y_train)
167
171
  self._check_same_length(x_train, y_train)
168
172
 
169
- lamda = self._process_noise_variance_settings(x_train, y_train, sigest, sigdf, sigquant, lamda)
170
173
  offset = self._process_offset_settings(y_train, offset)
171
174
  scale = self._process_scale_settings(y_train, k)
175
+ lamda, sigest = self._process_noise_variance_settings(x_train, y_train, sigest, sigdf, sigquant, lamda, offset)
172
176
 
173
177
  splits, max_split = self._determine_splits(x_train, numcut)
174
178
  x_train = self._bin_predictors(x_train, splits)
175
179
 
176
180
  y_train = self._transform_input(y_train, offset, scale)
177
- lamda = lamda / scale
181
+ lamda_scaled = lamda / (scale * scale)
178
182
 
179
- mcmc_state = self._setup_mcmc(x_train, y_train, max_split, lamda, sigdf, power, base, maxdepth, ntree)
183
+ mcmc_state = self._setup_mcmc(x_train, y_train, max_split, lamda_scaled, sigdf, power, base, maxdepth, ntree)
180
184
  final_state, burnin_trace, main_trace = self._run_mcmc(mcmc_state, ndpost, nskip, keepevery, printevery, seed)
181
185
 
182
186
  sigma = self._extract_sigma(main_trace, scale)
@@ -184,8 +188,10 @@ class BART:
184
188
 
185
189
  self.offset = offset
186
190
  self.scale = scale
187
- self.lamda = lamda * scale
191
+ self.lamda = lamda
192
+ self.sigest = sigest
188
193
  self.ntree = ntree
194
+ self.maxdepth = maxdepth
189
195
  self.sigma = sigma
190
196
  self.first_sigma = first_sigma
191
197
 
@@ -261,25 +267,25 @@ class BART:
261
267
  assert get_length(x1) == get_length(x2)
262
268
 
263
269
  @staticmethod
264
- def _process_noise_variance_settings(x_train, y_train, sigest, sigdf, sigquant, lamda):
270
+ def _process_noise_variance_settings(x_train, y_train, sigest, sigdf, sigquant, lamda, offset):
265
271
  if lamda is not None:
266
- return lamda
272
+ return lamda, None
267
273
  else:
268
274
  if sigest is not None:
269
275
  sigest2 = sigest * sigest
270
276
  elif y_train.size < 2:
271
277
  sigest2 = 1
272
278
  elif y_train.size <= x_train.shape[0]:
273
- sigest2 = jnp.var(y_train)
279
+ sigest2 = jnp.var(y_train - offset)
274
280
  else:
275
- _, chisq, rank, _ = jnp.linalg.lstsq(x_train.T, y_train)
281
+ _, chisq, rank, _ = jnp.linalg.lstsq(x_train.T, y_train - offset)
276
282
  chisq = chisq.squeeze(0)
277
283
  dof = len(y_train) - rank
278
284
  sigest2 = chisq / dof
279
285
  alpha = sigdf / 2
280
286
  invchi2 = jaxext.scipy.stats.invgamma.ppf(sigquant, alpha) / 2
281
287
  invchi2rid = invchi2 * sigdf
282
- return sigest2 / invchi2rid
288
+ return sigest2 / invchi2rid, jnp.sqrt(sigest2)
283
289
 
284
290
  @staticmethod
285
291
  def _process_offset_settings(y_train, offset):
@@ -315,7 +321,7 @@ class BART:
315
321
  p_nonterminal = base / (1 + depth).astype(float) ** power
316
322
  sigma2_alpha = sigdf / 2
317
323
  sigma2_beta = lamda * sigma2_alpha
318
- return mcmcstep.make_bart(
324
+ return mcmcstep.init(
319
325
  X=x_train,
320
326
  y=y_train,
321
327
  max_split=max_split,
@@ -348,13 +354,6 @@ class BART:
348
354
  return scale * jnp.sqrt(trace['sigma2'])
349
355
 
350
356
 
351
- def _predict_debug(self, x_test):
352
- from . import debug
353
- x_test, x_test_fmt = self._process_predictor_input(x_test)
354
- self._check_compatible_formats(x_test_fmt, self._x_train_fmt)
355
- x_test = self._bin_predictors(x_test, self._splits)
356
- return debug.trace_evaluate_trees(self._main_trace, x_test)
357
-
358
357
  def _show_tree(self, i_sample, i_tree, print_all=False):
359
358
  from . import debug
360
359
  trace = self._main_trace
@@ -379,7 +378,7 @@ class BART:
379
378
  def _compare_resid(self):
380
379
  bart = self._mcmc_state
381
380
  resid1 = bart['resid']
382
- yhat = grove.evaluate_tree_vmap_x(bart['X'], bart['leaf_trees'], bart['var_trees'], bart['split_trees'], jnp.float32)
381
+ yhat = grove.evaluate_forest(bart['X'], bart['leaf_trees'], bart['var_trees'], bart['split_trees'], jnp.float32)
383
382
  resid2 = bart['y'] - yhat
384
383
  return resid1, resid2
385
384
 
@@ -421,7 +420,5 @@ class BART:
421
420
 
422
421
  def _tree_goes_bad(self):
423
422
  bad = self._check_trees().astype(bool)
424
- bad_before = bad[:-1]
425
- bad_after = bad[1:]
426
- goes_bad = bad_after & ~bad_before
427
- return jnp.pad(goes_bad, [(1, 0), (0, 0)])
423
+ bad_before = jnp.pad(bad[:-1], [(1, 0), (0, 0)])
424
+ return bad & ~bad_before
bartz/mcmcloop.py CHANGED
@@ -100,15 +100,21 @@ def run_mcmc(bart, n_burn, n_save, n_skip, callback, key):
100
100
  def inner_loop(carry, _, tracelist, burnin):
101
101
  bart, i_total, i_skip, key = carry
102
102
  key, subkey = random.split(key)
103
- bart = mcmcstep.mcmc_step(bart, subkey)
103
+ bart = mcmcstep.step(bart, subkey)
104
104
  callback(bart=bart, burnin=burnin, i_total=i_total, i_skip=i_skip, **callback_kw)
105
105
  output = {key: bart[key] for key in tracelist}
106
106
  return (bart, i_total + 1, i_skip + 1, key), output
107
107
 
108
- # TODO avoid invoking this altogether if burnin is 0 to shorten compilation time & size
109
- carry = bart, 0, 0, key
110
- burnin_loop = functools.partial(inner_loop, tracelist=tracelist_burnin, burnin=True)
111
- (bart, i_total, _, key), burnin_trace = lax.scan(burnin_loop, carry, None, n_burn)
108
+ if n_burn > 0:
109
+ carry = bart, 0, 0, key
110
+ burnin_loop = functools.partial(inner_loop, tracelist=tracelist_burnin, burnin=True)
111
+ (bart, i_total, _, key), burnin_trace = lax.scan(burnin_loop, carry, None, n_burn)
112
+ else:
113
+ i_total = 0
114
+ burnin_trace = {
115
+ key: jnp.empty((0,) + bart[key].shape, bart[key].dtype)
116
+ for key in tracelist_burnin
117
+ }
112
118
 
113
119
  def outer_loop(carry, _):
114
120
  bart, i_total, key = carry
@@ -148,14 +154,17 @@ def make_simple_print_callback(printevery):
148
154
  prune_prop = bart['prune_prop_count'] / prop_total
149
155
  grow_acc = bart['grow_acc_count'] / bart['grow_prop_count']
150
156
  prune_acc = bart['prune_acc_count'] / bart['prune_prop_count']
151
- n_total = n_burn + n_save
157
+ n_total = n_burn + n_save * n_skip
152
158
  debug.callback(simple_print_callback_impl, burnin, i_total, n_total, grow_prop, grow_acc, prune_prop, prune_acc, printevery)
153
159
  return callback
154
160
 
155
161
  def simple_print_callback_impl(burnin, i_total, n_total, grow_prop, grow_acc, prune_prop, prune_acc, printevery):
156
162
  if (i_total + 1) % printevery == 0:
157
163
  burnin_flag = ' (burnin)' if burnin else ''
158
- print(f'Iteration {i_total + 1:4d}/{n_total:d} '
164
+ total_str = str(n_total)
165
+ ndigits = len(total_str)
166
+ i_str = str(i_total + 1).rjust(ndigits)
167
+ print(f'Iteration {i_str}/{total_str} '
159
168
  f'P_grow={grow_prop:.2f} P_prune={prune_prop:.2f} '
160
169
  f'A_grow={grow_acc:.2f} A_prune={prune_acc:.2f}{burnin_flag}')
161
170
 
@@ -177,6 +186,6 @@ def evaluate_trace(trace, X):
177
186
  The predictions for each iteration of the MCMC.
178
187
  """
179
188
  def loop(_, state):
180
- return None, grove.evaluate_tree_vmap_x(X, state['leaf_trees'], state['var_trees'], state['split_trees'], jnp.float32)
189
+ return None, grove.evaluate_forest(X, state['leaf_trees'], state['var_trees'], state['split_trees'], jnp.float32)
181
190
  _, y = lax.scan(loop, None, trace)
182
191
  return y