bartz 0.1.0__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,4 @@
1
- # bartz/src/bartz/interface.py
1
+ # bartz/src/bartz/BART.py
2
2
  #
3
3
  # Copyright (c) 2024, Giacomo Petrillo
4
4
  #
@@ -33,7 +33,7 @@ from . import mcmcstep
33
33
  from . import mcmcloop
34
34
  from . import prepcovars
35
35
 
36
- class BART:
36
+ class gbart:
37
37
  """
38
38
  Nonparametric regression with Bayesian Additive Regression Trees (BART).
39
39
 
@@ -133,7 +133,7 @@ class BART:
133
133
 
134
134
  Notes
135
135
  -----
136
- This interface imitates the function `wbart` from the R package `BART
136
+ This interface imitates the function `gbart` from the R package `BART
137
137
  <https://cran.r-project.org/package=BART>`_, but with these differences:
138
138
 
139
139
  - If `x_train` and `x_test` are matrices, they have one predictor per row
@@ -142,6 +142,7 @@ class BART:
142
142
  - `usequants` is always `True`.
143
143
  - `rm_const` is always `False`.
144
144
  - The default `numcut` is 255 instead of 100.
145
+ - A lot of functionality is missing (variable selection, discrete response).
145
146
  - There are some additional attributes, and some missing.
146
147
  """
147
148
 
bartz/__init__.py CHANGED
@@ -30,6 +30,11 @@ See the manual at https://gattocrucco.github.io/bartz/docs
30
30
 
31
31
  from ._version import __version__
32
32
 
33
- from .interface import BART
33
+ from . import BART
34
34
 
35
35
  from . import debug
36
+ from . import grove
37
+ from . import mcmcstep
38
+ from . import mcmcloop
39
+ from . import prepcovars
40
+ from . import jaxext
bartz/_version.py CHANGED
@@ -1 +1 @@
1
- __version__ = '0.1.0'
1
+ __version__ = '0.2.1'
bartz/debug.py CHANGED
@@ -6,6 +6,7 @@ from jax import lax
6
6
 
7
7
  from . import grove
8
8
  from . import mcmcstep
9
+ from . import jaxext
9
10
 
10
11
  def print_tree(leaf_tree, var_tree, split_tree, print_all=False):
11
12
 
@@ -83,7 +84,7 @@ def trace_depth_distr(split_trees_trace):
83
84
  def points_per_leaf_distr(var_tree, split_tree, X):
84
85
  traverse_tree = jax.vmap(grove.traverse_tree, in_axes=(1, None, None))
85
86
  indices = traverse_tree(X, var_tree, split_tree)
86
- count_tree = jnp.zeros(2 * split_tree.size, dtype=grove.minimal_unsigned_dtype(indices.size))
87
+ count_tree = jnp.zeros(2 * split_tree.size, dtype=jaxext.minimal_unsigned_dtype(indices.size))
87
88
  count_tree = count_tree.at[indices].add(1)
88
89
  is_leaf = grove.is_actual_leaf(split_tree, add_bottom_level=True).view(jnp.uint8)
89
90
  return jnp.bincount(count_tree, is_leaf, length=X.shape[1] + 1)
@@ -103,7 +104,7 @@ def trace_points_per_leaf_distr(bart, X):
103
104
  return distr
104
105
 
105
106
  def check_types(leaf_tree, var_tree, split_tree, max_split):
106
- expected_var_dtype = grove.minimal_unsigned_dtype(max_split.size - 1)
107
+ expected_var_dtype = jaxext.minimal_unsigned_dtype(max_split.size - 1)
107
108
  expected_split_dtype = max_split.dtype
108
109
  return var_tree.dtype == expected_var_dtype and split_tree.dtype == expected_split_dtype
109
110
 
@@ -117,7 +118,7 @@ def check_leaf_values(leaf_tree, var_tree, split_tree, max_split):
117
118
  return jnp.all(jnp.isfinite(leaf_tree))
118
119
 
119
120
  def check_stray_nodes(leaf_tree, var_tree, split_tree, max_split):
120
- index = jnp.arange(2 * split_tree.size, dtype=grove.minimal_unsigned_dtype(2 * split_tree.size - 1))
121
+ index = jnp.arange(2 * split_tree.size, dtype=jaxext.minimal_unsigned_dtype(2 * split_tree.size - 1))
121
122
  parent_index = index >> 1
122
123
  is_not_leaf = split_tree.at[index].get(mode='fill', fill_value=0) != 0
123
124
  parent_is_leaf = split_tree[parent_index] == 0
@@ -134,7 +135,7 @@ check_functions = [
134
135
  ]
135
136
 
136
137
  def check_tree(leaf_tree, var_tree, split_tree, max_split):
137
- error_type = grove.minimal_unsigned_dtype(2 ** len(check_functions) - 1)
138
+ error_type = jaxext.minimal_unsigned_dtype(2 ** len(check_functions) - 1)
138
139
  error = error_type(0)
139
140
  for i, func in enumerate(check_functions):
140
141
  ok = func(leaf_tree, var_tree, split_tree, max_split)
bartz/grove.py CHANGED
@@ -44,7 +44,6 @@ import functools
44
44
  import math
45
45
 
46
46
  import jax
47
-
48
47
  from jax import numpy as jnp
49
48
  from jax import lax
50
49
 
@@ -107,29 +106,47 @@ def traverse_tree(x, var_tree, split_tree):
107
106
 
108
107
  carry = (
109
108
  jnp.zeros((), bool),
110
- jnp.ones((), minimal_unsigned_dtype(2 * var_tree.size - 1)),
109
+ jnp.ones((), jaxext.minimal_unsigned_dtype(2 * var_tree.size - 1)),
111
110
  )
112
111
 
113
112
  def loop(carry, _):
114
113
  leaf_found, index = carry
115
114
 
116
- split = split_tree.at[index].get(mode='fill', fill_value=0)
117
- var = var_tree.at[index].get(mode='fill', fill_value=0)
115
+ split = split_tree[index]
116
+ var = var_tree[index]
118
117
 
119
- leaf_found |= split_tree.at[index].get(mode='fill', fill_value=0) == 0
118
+ leaf_found |= split == 0
120
119
  child_index = (index << 1) + (x[var] >= split)
121
120
  index = jnp.where(leaf_found, index, child_index)
122
121
 
123
122
  return (leaf_found, index), None
124
123
 
125
- # TODO
126
- # - unroll (how much? 5?)
127
- # - separate and special-case the last iteration
128
-
129
- depth = 1 + tree_depth(var_tree)
130
- (_, index), _ = lax.scan(loop, carry, None, depth)
124
+ depth = tree_depth(var_tree)
125
+ (_, index), _ = lax.scan(loop, carry, None, depth, unroll=16)
131
126
  return index
132
127
 
128
+ @functools.partial(jaxext.vmap_nodoc, in_axes=(None, 0, 0))
129
+ @functools.partial(jaxext.vmap_nodoc, in_axes=(1, None, None))
130
+ def traverse_forest(X, var_trees, split_trees):
131
+ """
132
+ Find the leaves into which the points fall.
133
+
134
+ Parameters
135
+ ----------
136
+ X : array (p, n)
137
+ The coordinates to evaluate the trees at.
138
+ var_trees : array (m, 2 ** (d - 1))
139
+ The decision axes of the trees.
140
+ split_trees : array (m, 2 ** (d - 1))
141
+ The decision boundaries of the trees.
142
+
143
+ Returns
144
+ -------
145
+ indices : array (m, n)
146
+ The indices of the leaves.
147
+ """
148
+ return traverse_tree(X, var_trees, split_trees)
149
+
133
150
  def evaluate_forest(X, leaf_trees, var_trees, split_trees, dtype):
134
151
  """
135
152
  Evaluate an ensemble of trees at an array of points.
@@ -138,7 +155,7 @@ def evaluate_forest(X, leaf_trees, var_trees, split_trees, dtype):
138
155
  ----------
139
156
  X : array (p, n)
140
157
  The coordinates to evaluate the trees at.
141
- leaf_trees : (m, 2 ** d)
158
+ leaf_trees : array (m, 2 ** d)
142
159
  The leaf values of the tree or forest. If the input is a forest, the
143
160
  first axis is the tree index, and the values are summed.
144
161
  var_trees : array (m, 2 ** (d - 1))
@@ -153,30 +170,13 @@ def evaluate_forest(X, leaf_trees, var_trees, split_trees, dtype):
153
170
  out : array (n,)
154
171
  The sum of the values of the trees at the points in `X`.
155
172
  """
156
- indices = _traverse_forest(X, var_trees, split_trees)
173
+ indices = traverse_forest(X, var_trees, split_trees)
157
174
  ntree, _ = leaf_trees.shape
158
- tree_index = jnp.arange(ntree, dtype=minimal_unsigned_dtype(ntree - 1))[:, None]
175
+ tree_index = jnp.arange(ntree, dtype=jaxext.minimal_unsigned_dtype(ntree - 1))[:, None]
159
176
  leaves = leaf_trees[tree_index, indices]
160
177
  return jnp.sum(leaves, axis=0, dtype=dtype)
161
- # this sum suggests to swap the vmaps, but I think it's better for X copying to keep it that way
162
-
163
- @functools.partial(jax.vmap, in_axes=(None, 0, 0))
164
- @functools.partial(jax.vmap, in_axes=(1, None, None))
165
- def _traverse_forest(X, var_trees, split_trees):
166
- return traverse_tree(X, var_trees, split_trees)
167
-
168
- def minimal_unsigned_dtype(max_value):
169
- """
170
- Return the smallest unsigned integer dtype that can represent a given
171
- maximum value.
172
- """
173
- if max_value < 2 ** 8:
174
- return jnp.uint8
175
- if max_value < 2 ** 16:
176
- return jnp.uint16
177
- if max_value < 2 ** 32:
178
- return jnp.uint32
179
- return jnp.uint64
178
+ # this sum suggests to swap the vmaps, but I think it's better for X
179
+ # copying to keep it that way
180
180
 
181
181
  def is_actual_leaf(split_tree, *, add_bottom_level=False):
182
182
  """
@@ -200,7 +200,7 @@ def is_actual_leaf(split_tree, *, add_bottom_level=False):
200
200
  if add_bottom_level:
201
201
  size *= 2
202
202
  is_leaf = jnp.concatenate([is_leaf, jnp.ones_like(is_leaf)])
203
- index = jnp.arange(size, dtype=minimal_unsigned_dtype(size - 1))
203
+ index = jnp.arange(size, dtype=jaxext.minimal_unsigned_dtype(size - 1))
204
204
  parent_index = index >> 1
205
205
  parent_nonleaf = split_tree[parent_index].astype(bool)
206
206
  parent_nonleaf = parent_nonleaf.at[1].set(True)
@@ -220,7 +220,7 @@ def is_leaves_parent(split_tree):
220
220
  is_leaves_parent : bool array (2 ** (d - 1),)
221
221
  The mask indicating which nodes have leaf children.
222
222
  """
223
- index = jnp.arange(split_tree.size, dtype=minimal_unsigned_dtype(2 * split_tree.size - 1))
223
+ index = jnp.arange(split_tree.size, dtype=jaxext.minimal_unsigned_dtype(2 * split_tree.size - 1))
224
224
  left_index = index << 1 # left child
225
225
  right_index = left_index + 1 # right child
226
226
  left_leaf = split_tree.at[left_index].get(mode='fill', fill_value=0) == 0
@@ -252,4 +252,4 @@ def tree_depths(tree_length):
252
252
  depth += 1
253
253
  depths.append(depth - 1)
254
254
  depths[0] = 0
255
- return jnp.array(depths, minimal_unsigned_dtype(max(depths)))
255
+ return jnp.array(depths, jaxext.minimal_unsigned_dtype(max(depths)))
bartz/jaxext.py CHANGED
@@ -10,10 +10,10 @@
10
10
  # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11
11
  # copies of the Software, and to permit persons to whom the Software is
12
12
  # furnished to do so, subject to the following conditions:
13
- #
13
+ #
14
14
  # The above copyright notice and this permission notice shall be included in all
15
15
  # copies or substantial portions of the Software.
16
- #
16
+ #
17
17
  # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18
18
  # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19
19
  # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
@@ -23,12 +23,19 @@
23
23
  # SOFTWARE.
24
24
 
25
25
  import functools
26
+ import math
27
+ import warnings
26
28
 
27
29
  from scipy import special
28
30
  import jax
29
31
  from jax import numpy as jnp
32
+ from jax import tree_util
33
+ from jax import lax
30
34
 
31
35
  def float_type(*args):
36
+ """
37
+ Determine the jax floating point result type given operands/types.
38
+ """
32
39
  t = jnp.result_type(*args)
33
40
  return jnp.sin(jnp.empty(0, t)).dtype
34
41
 
@@ -39,8 +46,8 @@ def castto(func, type):
39
46
  return newfunc
40
47
 
41
48
  def pure_callback_ufunc(callback, dtype, *args, excluded=None, **kwargs):
42
- """ version of jax.pure_callback that deals correctly with ufuncs,
43
- see https://github.com/google/jax/issues/17187 """
49
+ """ version of `jax.pure_callback` that deals correctly with ufuncs,
50
+ see `<https://github.com/google/jax/issues/17187>`_ """
44
51
  if excluded is None:
45
52
  excluded = ()
46
53
  shape = jnp.broadcast_shapes(*(
@@ -63,6 +70,7 @@ class scipy:
63
70
 
64
71
  class special:
65
72
 
73
+ @functools.wraps(special.gammainccinv)
66
74
  def gammainccinv(a, y):
67
75
  a = jnp.asarray(a)
68
76
  y = jnp.asarray(y)
@@ -73,13 +81,261 @@ class scipy:
73
81
  class stats:
74
82
 
75
83
  class invgamma:
76
-
84
+
77
85
  def ppf(q, a):
78
86
  return 1 / scipy.special.gammainccinv(a, q)
79
87
 
80
88
  @functools.wraps(jax.vmap)
81
89
  def vmap_nodoc(fun, *args, **kw):
90
+ """
91
+ Version of `jax.vmap` that preserves the docstring of the input function.
92
+ """
82
93
  doc = fun.__doc__
83
94
  fun = jax.vmap(fun, *args, **kw)
84
95
  fun.__doc__ = doc
85
96
  return fun
97
+
98
+ def huge_value(x):
99
+ """
100
+ Return the maximum value that can be stored in `x`.
101
+
102
+ Parameters
103
+ ----------
104
+ x : array
105
+ A numerical numpy or jax array.
106
+
107
+ Returns
108
+ -------
109
+ maxval : scalar
110
+ The maximum value allowed by `x`'s type (+inf for floats).
111
+ """
112
+ if jnp.issubdtype(x.dtype, jnp.integer):
113
+ return jnp.iinfo(x.dtype).max
114
+ else:
115
+ return jnp.inf
116
+
117
+ def minimal_unsigned_dtype(max_value):
118
+ """
119
+ Return the smallest unsigned integer dtype that can represent a given
120
+ maximum value (inclusive).
121
+ """
122
+ if max_value < 2 ** 8:
123
+ return jnp.uint8
124
+ if max_value < 2 ** 16:
125
+ return jnp.uint16
126
+ if max_value < 2 ** 32:
127
+ return jnp.uint32
128
+ return jnp.uint64
129
+
130
+ def signed_to_unsigned(int_dtype):
131
+ """
132
+ Map a signed integer type to its unsigned counterpart. Unsigned types are
133
+ passed through.
134
+ """
135
+ assert jnp.issubdtype(int_dtype, jnp.integer)
136
+ if jnp.issubdtype(int_dtype, jnp.unsignedinteger):
137
+ return int_dtype
138
+ if int_dtype == jnp.int8:
139
+ return jnp.uint8
140
+ if int_dtype == jnp.int16:
141
+ return jnp.uint16
142
+ if int_dtype == jnp.int32:
143
+ return jnp.uint32
144
+ if int_dtype == jnp.int64:
145
+ return jnp.uint64
146
+
147
+ def ensure_unsigned(x):
148
+ """
149
+ If x has signed integer type, cast it to the unsigned dtype of the same size.
150
+ """
151
+ return x.astype(signed_to_unsigned(x.dtype))
152
+
153
+ @functools.partial(jax.jit, static_argnums=(1,))
154
+ def unique(x, size, fill_value):
155
+ """
156
+ Restricted version of `jax.numpy.unique` that uses less memory.
157
+
158
+ Parameters
159
+ ----------
160
+ x : 1d array
161
+ The input array.
162
+ size : int
163
+ The length of the output.
164
+ fill_value : scalar
165
+ The value to fill the output with if `size` is greater than the number
166
+ of unique values in `x`.
167
+
168
+ Returns
169
+ -------
170
+ out : array (size,)
171
+ The unique values in `x`, sorted, and right-padded with `fill_value`.
172
+ actual_length : int
173
+ The number of used values in `out`.
174
+ """
175
+ if x.size == 0:
176
+ return jnp.full(size, fill_value, x.dtype), 0
177
+ if size == 0:
178
+ return jnp.empty(0, x.dtype), 0
179
+ x = jnp.sort(x)
180
+ def loop(carry, x):
181
+ i_out, i_in, last, out = carry
182
+ i_out = jnp.where(x == last, i_out, i_out + 1)
183
+ out = out.at[i_out].set(x)
184
+ return (i_out, i_in + 1, x, out), None
185
+ carry = 0, 0, x[0], jnp.full(size, fill_value, x.dtype)
186
+ (actual_length, _, _, out), _ = jax.lax.scan(loop, carry, x[:size])
187
+ return out, actual_length + 1
188
+
189
+ def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False):
190
+ """
191
+ Batch a function such that each batch is smaller than a threshold.
192
+
193
+ Parameters
194
+ ----------
195
+ func : callable
196
+ A jittable function with positional arguments only, with inputs and
197
+ outputs pytrees of arrays.
198
+ max_io_nbytes : int
199
+ The maximum number of input + output bytes in each batch.
200
+ in_axes : pytree of ints, default 0
201
+ A tree matching the structure of the function input, indicating along
202
+ which axes each array should be batched. If a single integer, it is
203
+ used for all arrays.
204
+ out_axes : pytree of ints, default 0
205
+ The same for outputs.
206
+ return_nbatches : bool, default False
207
+ If True, the number of batches is returned as a second output.
208
+
209
+ Returns
210
+ -------
211
+ batched_func : callable
212
+ A function with the same signature as `func`, but that processes the
213
+ input and output in batches in a loop.
214
+ """
215
+
216
+ def expand_axes(axes, tree):
217
+ if isinstance(axes, int):
218
+ return tree_util.tree_map(lambda _: axes, tree)
219
+ return tree_util.tree_map(lambda _, axis: axis, tree, axes)
220
+
221
+ def extract_size(axes, tree):
222
+ sizes = tree_util.tree_map(lambda x, axis: x.shape[axis], tree, axes)
223
+ sizes, _ = tree_util.tree_flatten(sizes)
224
+ assert all(s == sizes[0] for s in sizes)
225
+ return sizes[0]
226
+
227
+ def sum_nbytes(tree):
228
+ def nbytes(x):
229
+ return math.prod(x.shape) * x.dtype.itemsize
230
+ return tree_util.tree_reduce(lambda size, x: size + nbytes(x), tree, 0)
231
+
232
+ def next_divisor_small(dividend, min_divisor):
233
+ for divisor in range(min_divisor, int(math.sqrt(dividend)) + 1):
234
+ if dividend % divisor == 0:
235
+ return divisor
236
+ return dividend
237
+
238
+ def next_divisor_large(dividend, min_divisor):
239
+ max_inv_divisor = dividend // min_divisor
240
+ for inv_divisor in range(max_inv_divisor, 0, -1):
241
+ if dividend % inv_divisor == 0:
242
+ return dividend // inv_divisor
243
+ return dividend
244
+
245
+ def next_divisor(dividend, min_divisor):
246
+ if min_divisor * min_divisor <= dividend:
247
+ return next_divisor_small(dividend, min_divisor)
248
+ return next_divisor_large(dividend, min_divisor)
249
+
250
+ def move_axes_out(axes, tree):
251
+ def move_axis_out(axis, x):
252
+ if axis != 0:
253
+ return jnp.moveaxis(x, axis, 0)
254
+ return x
255
+ return tree_util.tree_map(move_axis_out, axes, tree)
256
+
257
+ def move_axes_in(axes, tree):
258
+ def move_axis_in(axis, x):
259
+ if axis != 0:
260
+ return jnp.moveaxis(x, 0, axis)
261
+ return x
262
+ return tree_util.tree_map(move_axis_in, axes, tree)
263
+
264
+ def batch(tree, nbatches):
265
+ def batch(x):
266
+ return x.reshape((nbatches, x.shape[0] // nbatches) + x.shape[1:])
267
+ return tree_util.tree_map(batch, tree)
268
+
269
+ def unbatch(tree):
270
+ def unbatch(x):
271
+ return x.reshape((x.shape[0] * x.shape[1],) + x.shape[2:])
272
+ return tree_util.tree_map(unbatch, tree)
273
+
274
+ def check_same(tree1, tree2):
275
+ def check_same(x1, x2):
276
+ assert x1.shape == x2.shape
277
+ assert x1.dtype == x2.dtype
278
+ tree_util.tree_map(check_same, tree1, tree2)
279
+
280
+ initial_in_axes = in_axes
281
+ initial_out_axes = out_axes
282
+
283
+ @jax.jit
284
+ @functools.wraps(func)
285
+ def batched_func(*args):
286
+ example_result = jax.eval_shape(func, *args)
287
+
288
+ in_axes = expand_axes(initial_in_axes, args)
289
+ out_axes = expand_axes(initial_out_axes, example_result)
290
+
291
+ in_size = extract_size(in_axes, args)
292
+ out_size = extract_size(out_axes, example_result)
293
+ assert in_size == out_size
294
+ size = in_size
295
+
296
+ total_nbytes = sum_nbytes(args) + sum_nbytes(example_result)
297
+ min_nbatches = total_nbytes // max_io_nbytes + bool(total_nbytes % max_io_nbytes)
298
+ nbatches = next_divisor(size, min_nbatches)
299
+ assert 1 <= nbatches <= size
300
+ assert size % nbatches == 0
301
+ assert total_nbytes % nbatches == 0
302
+
303
+ batch_nbytes = total_nbytes // nbatches
304
+ if batch_nbytes > max_io_nbytes:
305
+ assert size == nbatches
306
+ warnings.warn(f'batch_nbytes = {batch_nbytes} > max_io_nbytes = {max_io_nbytes}')
307
+
308
+ def loop(_, args):
309
+ args = move_axes_in(in_axes, args)
310
+ result = func(*args)
311
+ result = move_axes_out(out_axes, result)
312
+ return None, result
313
+
314
+ args = move_axes_out(in_axes, args)
315
+ args = batch(args, nbatches)
316
+ _, result = lax.scan(loop, None, args)
317
+ result = unbatch(result)
318
+ result = move_axes_in(out_axes, result)
319
+
320
+ check_same(example_result, result)
321
+
322
+ if return_nbatches:
323
+ return result, nbatches
324
+ return result
325
+
326
+ return batched_func
327
+
328
+ @tree_util.register_pytree_node_class
329
+ class LeafDict(dict):
330
+ """ dictionary that acts as a leaf in jax pytrees, to store compile-time
331
+ values """
332
+
333
+ def tree_flatten(self):
334
+ return (), self
335
+
336
+ @classmethod
337
+ def tree_unflatten(cls, aux_data, children):
338
+ return aux_data
339
+
340
+ def __repr__(self):
341
+ return f'{__class__.__name__}({super().__repr__()})'
bartz/mcmcloop.py CHANGED
@@ -52,7 +52,7 @@ def run_mcmc(bart, n_burn, n_save, n_skip, callback, key):
52
52
  n_save : int
53
53
  The number of iterations to save.
54
54
  n_skip : int
55
- The number of iterations to skip between each saved iteration.
55
+ The number of iterations to skip between each saved iteration, plus 1.
56
56
  callback : callable
57
57
  An arbitrary function run at each iteration, called with the following
58
58
  arguments, passed by keyword:
@@ -105,16 +105,19 @@ def run_mcmc(bart, n_burn, n_save, n_skip, callback, key):
105
105
  output = {key: bart[key] for key in tracelist}
106
106
  return (bart, i_total + 1, i_skip + 1, key), output
107
107
 
108
+ def empty_trace(bart, tracelist):
109
+ return {
110
+ key: jnp.empty((0,) + bart[key].shape, bart[key].dtype)
111
+ for key in tracelist
112
+ }
113
+
108
114
  if n_burn > 0:
109
115
  carry = bart, 0, 0, key
110
116
  burnin_loop = functools.partial(inner_loop, tracelist=tracelist_burnin, burnin=True)
111
117
  (bart, i_total, _, key), burnin_trace = lax.scan(burnin_loop, carry, None, n_burn)
112
118
  else:
113
119
  i_total = 0
114
- burnin_trace = {
115
- key: jnp.empty((0,) + bart[key].shape, bart[key].dtype)
116
- for key in tracelist_burnin
117
- }
120
+ burnin_trace = empty_trace(bart, tracelist_burnin)
118
121
 
119
122
  def outer_loop(carry, _):
120
123
  bart, i_total, key = carry
@@ -124,8 +127,11 @@ def run_mcmc(bart, n_burn, n_save, n_skip, callback, key):
124
127
  output = {key: bart[key] for key in tracelist_main}
125
128
  return (bart, i_total, key), output
126
129
 
127
- carry = bart, i_total, key
128
- (bart, _, _), main_trace = lax.scan(outer_loop, carry, None, n_save)
130
+ if n_save > 0:
131
+ carry = bart, i_total, key
132
+ (bart, _, _), main_trace = lax.scan(outer_loop, carry, None, n_save)
133
+ else:
134
+ main_trace = empty_trace(bart, tracelist_main)
129
135
 
130
136
  return bart, burnin_trace, main_trace
131
137
 
@@ -133,7 +139,8 @@ def run_mcmc(bart, n_burn, n_save, n_skip, callback, key):
133
139
 
134
140
  @functools.lru_cache
135
141
  # cache to make the callback function object unique, such that the jit
136
- # of run_mcmc recognizes it
142
+ # of run_mcmc recognizes it => with the callback state, I can make
143
+ # printevery a runtime quantity
137
144
  def make_simple_print_callback(printevery):
138
145
  """
139
146
  Create a logging callback function for MCMC iterations.
@@ -155,11 +162,12 @@ def make_simple_print_callback(printevery):
155
162
  grow_acc = bart['grow_acc_count'] / bart['grow_prop_count']
156
163
  prune_acc = bart['prune_acc_count'] / bart['prune_prop_count']
157
164
  n_total = n_burn + n_save * n_skip
158
- debug.callback(simple_print_callback_impl, burnin, i_total, n_total, grow_prop, grow_acc, prune_prop, prune_acc, printevery)
165
+ printcond = (i_total + 1) % printevery == 0
166
+ debug.callback(_simple_print_callback, burnin, i_total, n_total, grow_prop, grow_acc, prune_prop, prune_acc, printcond)
159
167
  return callback
160
168
 
161
- def simple_print_callback_impl(burnin, i_total, n_total, grow_prop, grow_acc, prune_prop, prune_acc, printevery):
162
- if (i_total + 1) % printevery == 0:
169
+ def _simple_print_callback(burnin, i_total, n_total, grow_prop, grow_acc, prune_prop, prune_acc, printcond):
170
+ if printcond:
163
171
  burnin_flag = ' (burnin)' if burnin else ''
164
172
  total_str = str(n_total)
165
173
  ndigits = len(total_str)
bartz/mcmcstep.py CHANGED
@@ -34,7 +34,6 @@ range of possible values.
34
34
  """
35
35
 
36
36
  import functools
37
- import math
38
37
 
39
38
  import jax
40
39
  from jax import random
@@ -55,6 +54,7 @@ def init(*,
55
54
  small_float=jnp.float32,
56
55
  large_float=jnp.float32,
57
56
  min_points_per_leaf=None,
57
+ suffstat_batch_size='auto',
58
58
  ):
59
59
  """
60
60
  Make a BART posterior sampling MCMC initial state.
@@ -82,6 +82,10 @@ def init(*,
82
82
  The dtype for scalars, small arrays, and arrays which require accuracy.
83
83
  min_points_per_leaf : int, optional
84
84
  The minimum number of data points in a leaf node. 0 if not specified.
85
+ suffstat_batch_size : int, None, str, default 'auto'
86
+ The batch size for computing sufficient statistics. `None` for no
87
+ batching. If 'auto', pick a value based on the device of `y`, or the
88
+ default device.
85
89
 
86
90
  Returns
87
91
  -------
@@ -104,8 +108,9 @@ def init(*,
104
108
  The number of grow/prune proposals made during one full MCMC cycle.
105
109
  'grow_acc_count', 'prune_acc_count' : int
106
110
  The number of grow/prune moves accepted during one full MCMC cycle.
107
- 'p_nonterminal' : large_float array (d - 1,)
108
- The probability of a nonterminal node at each depth.
111
+ 'p_nonterminal' : large_float array (d,)
112
+ The probability of a nonterminal node at each depth, padded with a
113
+ zero.
109
114
  'sigma2_alpha' : large_float
110
115
  The shape parameter of the inverse gamma prior on the noise variance.
111
116
  'sigma2_beta' : large_float
@@ -121,18 +126,36 @@ def init(*,
121
126
  'affluence_trees' : bool array (num_trees, 2 ** (d - 1)) or None
122
127
  Whether a non-bottom leaf node contains at least twice `min_points_per_leaf`
123
128
  datapoints. If `min_points_per_leaf` is not specified, this is None.
129
+ 'opt' : LeafDict
130
+ A dictionary with config values:
131
+
132
+ 'suffstat_batch_size' : int or None
133
+ The batch size for computing sufficient statistics.
134
+ 'small_float' : dtype
135
+ The dtype for large arrays used in the algorithm.
136
+ 'large_float' : dtype
137
+ The dtype for scalars, small arrays, and arrays which require
138
+ accuracy.
139
+ 'require_min_points' : bool
140
+ Whether the `min_points_per_leaf` parameter is specified.
124
141
  """
125
142
 
126
143
  p_nonterminal = jnp.asarray(p_nonterminal, large_float)
127
- max_depth = p_nonterminal.size + 1
144
+ p_nonterminal = jnp.pad(p_nonterminal, (0, 1))
145
+ max_depth = p_nonterminal.size
128
146
 
129
147
  @functools.partial(jax.vmap, in_axes=None, out_axes=0, axis_size=num_trees)
130
148
  def make_forest(max_depth, dtype):
131
149
  return grove.make_tree(max_depth, dtype)
132
150
 
151
+ small_float = jnp.dtype(small_float)
152
+ large_float = jnp.dtype(large_float)
153
+ y = jnp.asarray(y, small_float)
154
+ suffstat_batch_size = _choose_suffstat_batch_size(suffstat_batch_size, y)
155
+
133
156
  bart = dict(
134
157
  leaf_trees=make_forest(max_depth, small_float),
135
- var_trees=make_forest(max_depth - 1, grove.minimal_unsigned_dtype(X.shape[0] - 1)),
158
+ var_trees=make_forest(max_depth - 1, jaxext.minimal_unsigned_dtype(X.shape[0] - 1)),
136
159
  split_trees=make_forest(max_depth - 1, max_split.dtype),
137
160
  resid=jnp.asarray(y, large_float),
138
161
  sigma2=jnp.ones((), large_float),
@@ -143,9 +166,9 @@ def init(*,
143
166
  p_nonterminal=p_nonterminal,
144
167
  sigma2_alpha=jnp.asarray(sigma2_alpha, large_float),
145
168
  sigma2_beta=jnp.asarray(sigma2_beta, large_float),
146
- max_split=max_split,
147
- y=jnp.asarray(y, small_float),
148
- X=X,
169
+ max_split=jnp.asarray(max_split),
170
+ y=y,
171
+ X=jnp.asarray(X),
149
172
  min_points_per_leaf=(
150
173
  None if min_points_per_leaf is None else
151
174
  jnp.asarray(min_points_per_leaf)
@@ -154,10 +177,39 @@ def init(*,
154
177
  None if min_points_per_leaf is None else
155
178
  make_forest(max_depth - 1, bool).at[:, 1].set(y.size >= 2 * min_points_per_leaf)
156
179
  ),
180
+ opt=jaxext.LeafDict(
181
+ suffstat_batch_size=suffstat_batch_size,
182
+ small_float=small_float,
183
+ large_float=large_float,
184
+ require_min_points=min_points_per_leaf is not None,
185
+ ),
157
186
  )
158
187
 
159
188
  return bart
160
189
 
190
+ def _choose_suffstat_batch_size(size, y):
191
+ if size == 'auto':
192
+ try:
193
+ device = y.devices().pop()
194
+ except jax.errors.ConcretizationTypeError:
195
+ device = jax.devices()[0]
196
+ platform = device.platform
197
+
198
+ if platform == 'cpu':
199
+ return None
200
+ # maybe I should batch residuals (not counts) for numerical
201
+ # accuracy, even if it's slower
202
+ elif platform == 'gpu':
203
+ return 128 # 128 is good on A100, and V100 at high n
204
+ # 512 is good on T4, and V100 at low n
205
+ else:
206
+ raise KeyError(f'Unknown platform: {platform}')
207
+
208
+ elif size is not None:
209
+ return int(size)
210
+
211
+ return size
212
+
161
213
  def step(bart, key):
162
214
  """
163
215
  Perform one full MCMC step on a BART state.
@@ -196,11 +248,14 @@ def sample_trees(bart, key):
196
248
 
197
249
  Notes
198
250
  -----
199
- This function zeroes the proposal counters.
251
+ This function zeroes the proposal counters before using them.
200
252
  """
253
+ bart = bart.copy()
201
254
  key, subkey = random.split(key)
202
255
  grow_moves, prune_moves = sample_moves(bart, subkey)
203
- return accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, key)
256
+ bart['var_trees'] = grow_moves['var_tree']
257
+ grow_leaf_indices = grove.traverse_forest(bart['X'], grow_moves['var_tree'], grow_moves['split_tree'])
258
+ return accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, grow_leaf_indices, key)
204
259
 
205
260
  def sample_moves(bart, key):
206
261
  """
@@ -216,20 +271,7 @@ def sample_moves(bart, key):
216
271
  Returns
217
272
  -------
218
273
  grow_moves, prune_moves : dict
219
- The proposals for grow and prune moves, with these fields:
220
-
221
- 'allowed' : bool array (num_trees,)
222
- Whether the move is possible.
223
- 'node' : int array (num_trees,)
224
- The index of the leaf to grow or node to prune.
225
- 'var_tree' : int array (num_trees, 2 ** (d - 1),)
226
- The new decision axes of the tree.
227
- 'split_tree' : int array (num_trees, 2 ** (d - 1),)
228
- The new decision boundaries of the tree.
229
- 'partial_ratio' : float array (num_trees,)
230
- A factor of the Metropolis-Hastings ratio of the move. It lacks
231
- the likelihood ratio, and the probability of proposing the prune
232
- move. For the prune move, the ratio is inverted.
274
+ The proposals for grow and prune moves. See `grow_move` and `prune_move`.
233
275
  """
234
276
  key = random.split(key, bart['var_trees'].shape[0])
235
277
  return sample_moves_vmap_trees(bart['var_trees'], bart['split_trees'], bart['affluence_trees'], bart['max_split'], bart['p_nonterminal'], key)
@@ -260,7 +302,7 @@ def grow_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, ke
260
302
  Whether a leaf has enough points to be grown.
261
303
  max_split : array (p,)
262
304
  The maximum split index for each variable.
263
- p_nonterminal : array (d - 1,)
305
+ p_nonterminal : array (d,)
264
306
  The probability of a nonterminal node at each depth.
265
307
  key : jax.dtypes.prng_key array
266
308
  A jax random key.
@@ -292,16 +334,16 @@ def grow_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, ke
292
334
  var_tree = var_tree.at[leaf_to_grow].set(var.astype(var_tree.dtype))
293
335
 
294
336
  split = choose_split(var_tree, split_tree, max_split, leaf_to_grow, key2)
295
- new_split_tree = split_tree.at[leaf_to_grow].set(split.astype(split_tree.dtype))
337
+ split_tree = split_tree.at[leaf_to_grow].set(split.astype(split_tree.dtype))
296
338
 
297
- ratio = compute_partial_ratio(num_growable, num_prunable, p_nonterminal, leaf_to_grow, split_tree, new_split_tree)
339
+ ratio = compute_partial_ratio(num_growable, num_prunable, p_nonterminal, leaf_to_grow, split_tree)
298
340
 
299
341
  return dict(
300
342
  allowed=allowed,
301
343
  node=leaf_to_grow,
302
- var_tree=var_tree,
303
- split_tree=new_split_tree,
304
344
  partial_ratio=ratio,
345
+ var_tree=var_tree,
346
+ split_tree=split_tree,
305
347
  )
306
348
 
307
349
  def choose_leaf(split_tree, affluence_tree, key):
@@ -464,7 +506,7 @@ def ancestor_variables(var_tree, max_split, node_index):
464
506
  the parent. Unused spots are filled with `p`.
465
507
  """
466
508
  max_num_ancestors = grove.tree_depth(var_tree) - 1
467
- ancestor_vars = jnp.zeros(max_num_ancestors, grove.minimal_unsigned_dtype(max_split.size))
509
+ ancestor_vars = jnp.zeros(max_num_ancestors, jaxext.minimal_unsigned_dtype(max_split.size))
468
510
  carry = ancestor_vars.size - 1, node_index, ancestor_vars
469
511
  def loop(carry, _):
470
512
  i, index, ancestor_vars = carry
@@ -569,7 +611,7 @@ def choose_split(var_tree, split_tree, max_split, leaf_index, key):
569
611
  l, r = split_range(var_tree, split_tree, max_split, leaf_index, var)
570
612
  return random.randint(key, (), l, r)
571
613
 
572
- def compute_partial_ratio(num_growable, num_prunable, p_nonterminal, leaf_to_grow, initial_split_tree, new_split_tree):
614
+ def compute_partial_ratio(num_growable, num_prunable, p_nonterminal, leaf_to_grow, new_split_tree):
573
615
  """
574
616
  Compute the product of the transition and prior ratios of a grow move.
575
617
 
@@ -580,12 +622,10 @@ def compute_partial_ratio(num_growable, num_prunable, p_nonterminal, leaf_to_gro
580
622
  num_prunable : int
581
623
  The number of leaf parents that could be pruned, after converting the
582
624
  leaf to be grown to a non-terminal node.
583
- p_nonterminal : array (d - 1,)
625
+ p_nonterminal : array (d,)
584
626
  The probability of a nonterminal node at each depth.
585
627
  leaf_to_grow : int
586
628
  The index of the leaf to grow.
587
- initial_split_tree : array (2 ** (d - 1),)
588
- The splitting points of the tree, before the leaf is grown.
589
629
  new_split_tree : array (2 ** (d - 1),)
590
630
  The splitting points of the tree, after the leaf is grown.
591
631
 
@@ -600,14 +640,18 @@ def compute_partial_ratio(num_growable, num_prunable, p_nonterminal, leaf_to_gro
600
640
  # the two ratios also contain factors num_available_split *
601
641
  # num_available_var, but they cancel out
602
642
 
603
- prune_was_allowed = prune_allowed(initial_split_tree)
604
- p_grow = jnp.where(prune_was_allowed, 0.5, 1)
643
+ prune_allowed = leaf_to_grow != 1
644
+ # prune allowed <---> the initial tree is not a root
645
+ # leaf to grow is root --> the tree can only be a root
646
+ # tree is a root --> the only leaf I can grow is root
647
+
648
+ p_grow = jnp.where(prune_allowed, 0.5, 1)
605
649
 
606
650
  trans_ratio = num_growable / (p_grow * num_prunable)
607
651
 
608
- depth = grove.tree_depths(initial_split_tree.size)[leaf_to_grow]
652
+ depth = grove.tree_depths(new_split_tree.size)[leaf_to_grow]
609
653
  p_parent = p_nonterminal[depth]
610
- cp_children = 1 - p_nonterminal.at[depth + 1].get(mode='fill', fill_value=0)
654
+ cp_children = 1 - p_nonterminal[depth + 1]
611
655
  tree_ratio = cp_children * cp_children * p_parent / (1 - p_parent)
612
656
 
613
657
  return trans_ratio * tree_ratio
@@ -626,7 +670,7 @@ def prune_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, k
626
670
  Whether a leaf has enough points to be grown.
627
671
  max_split : array (p,)
628
672
  The maximum split index for each variable.
629
- p_nonterminal : array (d - 1,)
673
+ p_nonterminal : array (d,)
630
674
  The probability of a nonterminal node at each depth.
631
675
  key : jax.dtypes.prng_key array
632
676
  A jax random key.
@@ -639,28 +683,20 @@ def prune_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, k
639
683
  'allowed' : bool
640
684
  Whether the move is possible.
641
685
  'node' : int
642
- The index of the leaf to grow.
643
- 'var_tree' : array (2 ** (d - 1),)
644
- The new decision axes of the tree.
645
- 'split_tree' : array (2 ** (d - 1),)
646
- The new decision boundaries of the tree.
686
+ The index of the node to prune.
647
687
  'partial_ratio' : float
648
688
  A factor of the Metropolis-Hastings ratio of the move. It lacks
649
689
  the likelihood ratio and the probability of proposing the prune
650
690
  move. This ratio is inverted.
651
691
  """
652
692
  node_to_prune, num_prunable, num_growable = choose_leaf_parent(split_tree, affluence_tree, key)
653
- allowed = prune_allowed(split_tree)
654
-
655
- new_split_tree = split_tree.at[node_to_prune].set(0)
693
+ allowed = split_tree[1].astype(bool) # allowed iff the tree is not a root
656
694
 
657
- ratio = compute_partial_ratio(num_growable, num_prunable, p_nonterminal, node_to_prune, new_split_tree, split_tree)
695
+ ratio = compute_partial_ratio(num_growable, num_prunable, p_nonterminal, node_to_prune, split_tree)
658
696
 
659
697
  return dict(
660
698
  allowed=allowed,
661
699
  node=node_to_prune,
662
- var_tree=var_tree,
663
- split_tree=new_split_tree,
664
700
  partial_ratio=ratio, # it is inverted in accept_move_and_sample_leaves
665
701
  )
666
702
 
@@ -702,29 +738,37 @@ def choose_leaf_parent(split_tree, affluence_tree, key):
702
738
 
703
739
  return node_to_prune, num_prunable, num_growable
704
740
 
705
- def prune_allowed(split_tree):
741
+ def accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, grow_leaf_indices, key):
706
742
  """
707
- Return whether a prune move is allowed.
743
+ Accept or reject the proposed moves and sample the new leaf values.
708
744
 
709
745
  Parameters
710
746
  ----------
711
- split_tree : array (2 ** (d - 1),)
712
- The splitting points of the tree.
747
+ bart : dict
748
+ A BART mcmc state.
749
+ grow_moves : dict
750
+ The proposals for grow moves, batched over the first axis. See
751
+ `grow_move`.
752
+ prune_moves : dict
753
+ The proposals for prune moves, batched over the first axis. See
754
+ `prune_move`.
755
+ grow_leaf_indices : int array (num_trees, n)
756
+ The leaf indices of the trees proposed by the grow move.
757
+ key : jax.dtypes.prng_key array
758
+ A jax random key.
713
759
 
714
760
  Returns
715
761
  -------
716
- allowed : bool
717
- Whether a prune move is allowed.
762
+ bart : dict
763
+ The new BART mcmc state.
718
764
  """
719
- return split_tree.at[1].get(mode='fill', fill_value=0).astype(bool)
720
-
721
- def accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, key):
722
765
  bart = bart.copy()
723
766
  def loop(carry, item):
724
767
  resid = carry.pop('resid')
725
768
  resid, carry, trees = accept_move_and_sample_leaves(
726
769
  bart['X'],
727
770
  len(bart['leaf_trees']),
771
+ bart['opt']['suffstat_batch_size'],
728
772
  resid,
729
773
  bart['sigma2'],
730
774
  bart['min_points_per_leaf'],
@@ -740,11 +784,11 @@ def accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, key):
740
784
  carry['resid'] = bart['resid']
741
785
  items = (
742
786
  bart['leaf_trees'],
743
- bart['var_trees'],
744
787
  bart['split_trees'],
745
788
  bart['affluence_trees'],
746
789
  grow_moves,
747
790
  prune_moves,
791
+ grow_leaf_indices,
748
792
  random.split(key, len(bart['leaf_trees'])),
749
793
  )
750
794
  carry, trees = lax.scan(loop, carry, items)
@@ -752,11 +796,50 @@ def accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, key):
752
796
  bart.update(trees)
753
797
  return bart
754
798
 
755
- def accept_move_and_sample_leaves(X, ntree, resid, sigma2, min_points_per_leaf, counts, leaf_tree, var_tree, split_tree, affluence_tree, grow_move, prune_move, key):
756
-
757
- # compute leaf indices according to grow move tree
758
- traverse_tree = jax.vmap(grove.traverse_tree, in_axes=(1, None, None))
759
- grow_leaf_indices = traverse_tree(X, grow_move['var_tree'], grow_move['split_tree'])
799
+ def accept_move_and_sample_leaves(X, ntree, suffstat_batch_size, resid, sigma2, min_points_per_leaf, counts, leaf_tree, split_tree, affluence_tree, grow_move, prune_move, grow_leaf_indices, key):
800
+ """
801
+ Accept or reject a proposed move and sample the new leaf values.
802
+
803
+ Parameters
804
+ ----------
805
+ X : int array (p, n)
806
+ The predictors.
807
+ ntree : int
808
+ The number of trees in the forest.
809
+ suffstat_batch_size : int, None
810
+ The batch size for computing sufficient statistics.
811
+ resid : float array (n,)
812
+ The residuals (data minus forest value).
813
+ sigma2 : float
814
+ The noise variance.
815
+ min_points_per_leaf : int or None
816
+ The minimum number of data points in a leaf node.
817
+ counts : dict
818
+ The acceptance counts from the mcmc state dict.
819
+ leaf_tree : float array (2 ** d,)
820
+ The leaf values of the tree.
821
+ split_tree : int array (2 ** (d - 1),)
822
+ The decision boundaries of the tree.
823
+ affluence_tree : bool array (2 ** (d - 1),) or None
824
+ Whether a leaf has enough points to be grown.
825
+ grow_move : dict
826
+ The proposal for the grow move. See `grow_move`.
827
+ prune_move : dict
828
+ The proposal for the prune move. See `prune_move`.
829
+ grow_leaf_indices : int array (n,)
830
+ The leaf indices of the tree proposed by the grow move.
831
+ key : jax.dtypes.prng_key array
832
+ A jax random key.
833
+
834
+ Returns
835
+ -------
836
+ resid : float array (n,)
837
+ The updated residuals (data minus forest value).
838
+ counts : dict
839
+ The updated acceptance counts.
840
+ trees : dict
841
+ The updated tree arrays.
842
+ """
760
843
 
761
844
  # compute leaf indices in starting tree
762
845
  grow_node = grow_move['node']
@@ -782,10 +865,7 @@ def accept_move_and_sample_leaves(X, ntree, resid, sigma2, min_points_per_leaf,
782
865
  resid += leaf_tree[leaf_indices]
783
866
 
784
867
  # aggregate residuals and count units per leaf
785
- grow_resid_tree = jnp.zeros_like(leaf_tree, sigma2.dtype)
786
- grow_resid_tree = grow_resid_tree.at[grow_leaf_indices].add(resid)
787
- grow_count_tree = jnp.zeros_like(leaf_tree, grove.minimal_unsigned_dtype(resid.size))
788
- grow_count_tree = grow_count_tree.at[grow_leaf_indices].add(1)
868
+ grow_resid_tree, grow_count_tree = sufficient_stat(resid, grow_leaf_indices, leaf_tree.size, suffstat_batch_size)
789
869
 
790
870
  # compute aggregations in starting tree
791
871
  # I do not zero the children because garbage there does not matter
@@ -833,10 +913,10 @@ def accept_move_and_sample_leaves(X, ntree, resid, sigma2, min_points_per_leaf,
833
913
 
834
914
  # pick trees for chosen move
835
915
  trees = {}
836
- var_tree = jnp.where(do_grow, grow_move['var_tree'], var_tree)
837
916
  split_tree = jnp.where(do_grow, grow_move['split_tree'], split_tree)
838
- var_tree = jnp.where(do_prune, prune_move['var_tree'], var_tree)
839
- split_tree = jnp.where(do_prune, prune_move['split_tree'], split_tree)
917
+ # the prune var tree is equal to the initial one, because I leave garbage values behind
918
+ split_tree = split_tree.at[prune_node].set(
919
+ jnp.where(do_prune, 0, split_tree[prune_node]))
840
920
  if min_points_per_leaf is not None:
841
921
  affluence_tree = jnp.where(do_grow, grow_affluence_tree, affluence_tree)
842
922
  affluence_tree = jnp.where(do_prune, prune_affluence_tree, affluence_tree)
@@ -869,13 +949,60 @@ def accept_move_and_sample_leaves(X, ntree, resid, sigma2, min_points_per_leaf,
869
949
  # pack trees
870
950
  trees = {
871
951
  'leaf_trees': leaf_tree,
872
- 'var_trees': var_tree,
873
952
  'split_trees': split_tree,
874
953
  'affluence_trees': affluence_tree,
875
954
  }
876
955
 
877
956
  return resid, counts, trees
878
957
 
958
+ def sufficient_stat(resid, leaf_indices, tree_size, batch_size):
959
+ """
960
+ Compute the sufficient statistics for the likelihood ratio of a tree move.
961
+
962
+ Parameters
963
+ ----------
964
+ resid : float array (n,)
965
+ The residuals (data minus forest value).
966
+ leaf_indices : int array (n,)
967
+ The leaf indices of the tree (in which leaf each data point falls into).
968
+ tree_size : int
969
+ The size of the tree array (2 ** d).
970
+ batch_size : int, None
971
+ The batch size for the aggregation. Batching increases numerical
972
+ accuracy and parallelism.
973
+
974
+ Returns
975
+ -------
976
+ resid_tree : float array (2 ** d,)
977
+ The sum of the residuals at data points in each leaf.
978
+ count_tree : int array (2 ** d,)
979
+ The number of data points in each leaf.
980
+ """
981
+ if batch_size is None:
982
+ aggr_func = _aggregate_scatter
983
+ else:
984
+ aggr_func = functools.partial(_aggregate_batched, batch_size=batch_size)
985
+ resid_tree = aggr_func(resid, leaf_indices, tree_size, jnp.float32)
986
+ count_tree = aggr_func(1, leaf_indices, tree_size, jnp.uint32)
987
+ return resid_tree, count_tree
988
+
989
+ def _aggregate_scatter(values, indices, size, dtype):
990
+ return (jnp
991
+ .zeros(size, dtype)
992
+ .at[indices]
993
+ .add(values)
994
+ )
995
+
996
+ def _aggregate_batched(values, indices, size, dtype, batch_size):
997
+ nbatches = indices.size // batch_size + bool(indices.size % batch_size)
998
+ batch_indices = jnp.arange(indices.size) // batch_size
999
+ return (jnp
1000
+ .zeros((nbatches, size), dtype)
1001
+ .at[batch_indices, indices]
1002
+ .add(values)
1003
+ .sum(axis=0)
1004
+ )
1005
+
879
1006
  def compute_p_prune_back(new_split_tree, new_affluence_tree):
880
1007
  """
881
1008
  Compute the probability of proposing a prune move after doing a grow move.
bartz/prepcovars.py CHANGED
@@ -27,8 +27,10 @@ import functools
27
27
  import jax
28
28
  from jax import numpy as jnp
29
29
 
30
+ from . import jaxext
30
31
  from . import grove
31
32
 
33
+ @functools.partial(jax.jit, static_argnums=(1,))
32
34
  def quantilized_splits_from_matrix(X, max_bins):
33
35
  """
34
36
  Determine bins that make the distribution of each predictor uniform.
@@ -52,48 +54,41 @@ def quantilized_splits_from_matrix(X, max_bins):
52
54
  The number of actually used values in each row of `splits`.
53
55
  """
54
56
  out_length = min(max_bins, X.shape[1]) - 1
55
- return quantilized_splits_from_matrix_impl(X, out_length)
57
+ # return _quantilized_splits_from_matrix(X, out_length)
58
+ @functools.partial(jaxext.autobatch, max_io_nbytes=500_000_000)
59
+ def func(X):
60
+ return _quantilized_splits_from_matrix(X, out_length)
61
+ return func(X)
56
62
 
57
63
  @functools.partial(jax.vmap, in_axes=(0, None))
58
- def quantilized_splits_from_matrix_impl(x, out_length):
59
- huge = huge_value(x)
60
- u = jnp.unique(x, size=x.size, fill_value=huge)
61
- actual_length = jnp.count_nonzero(u < huge) - 1
62
- midpoints = (u[1:] + u[:-1]) / 2
64
+ def _quantilized_splits_from_matrix(x, out_length):
65
+ huge = jaxext.huge_value(x)
66
+ u, actual_length = jaxext.unique(x, size=x.size, fill_value=huge)
67
+ actual_length -= 1
68
+ if jnp.issubdtype(x.dtype, jnp.integer):
69
+ midpoints = u[:-1] + jaxext.ensure_unsigned(u[1:] - u[:-1]) // 2
70
+ indices = jnp.arange(midpoints.size, dtype=jaxext.minimal_unsigned_dtype(midpoints.size - 1))
71
+ midpoints = jnp.where(indices < actual_length, midpoints, huge)
72
+ else:
73
+ midpoints = (u[1:] + u[:-1]) / 2
63
74
  indices = jnp.linspace(-1, actual_length, out_length + 2)[1:-1]
64
- indices = jnp.around(indices).astype(grove.minimal_unsigned_dtype(midpoints.size - 1))
75
+ indices = jnp.around(indices).astype(jaxext.minimal_unsigned_dtype(midpoints.size - 1))
65
76
  # indices calculation with float rather than int to avoid potential
66
77
  # overflow with int32, and to round to nearest instead of rounding down
67
78
  decimated_midpoints = midpoints[indices]
68
79
  truncated_midpoints = midpoints[:out_length]
69
80
  splits = jnp.where(actual_length > out_length, decimated_midpoints, truncated_midpoints)
70
81
  max_split = jnp.minimum(actual_length, out_length)
71
- max_split = max_split.astype(grove.minimal_unsigned_dtype(out_length))
82
+ max_split = max_split.astype(jaxext.minimal_unsigned_dtype(out_length))
72
83
  return splits, max_split
73
84
 
74
- def huge_value(x):
75
- """
76
- Return the maximum value that can be stored in `x`.
77
-
78
- Parameters
79
- ----------
80
- x : array
81
- A numerical numpy or jax array.
82
-
83
- Returns
84
- -------
85
- maxval : scalar
86
- The maximum value allowed by `x`'s type (+inf for floats).
87
- """
88
- if jnp.issubdtype(x.dtype, jnp.integer):
89
- return jnp.iinfo(x.dtype).max
90
- else:
91
- return jnp.inf
92
-
85
+ @jax.jit
93
86
  def bin_predictors(X, splits):
94
87
  """
95
88
  Bin the predictors according to the given splits.
96
89
 
90
+ A value ``x`` is mapped to bin ``i`` iff ``splits[i - 1] < x <= splits[i]``.
91
+
97
92
  Parameters
98
93
  ----------
99
94
  X : array (p, n)
@@ -110,9 +105,9 @@ def bin_predictors(X, splits):
110
105
  A matrix with `p` predictors and `n` observations, where each predictor
111
106
  has been replaced by the index of the bin it falls into.
112
107
  """
113
- return bin_predictors_impl(X, splits)
108
+ return _bin_predictors(X, splits)
114
109
 
115
110
  @jax.vmap
116
- def bin_predictors_impl(x, splits):
117
- dtype = grove.minimal_unsigned_dtype(splits.size)
111
+ def _bin_predictors(x, splits):
112
+ dtype = jaxext.minimal_unsigned_dtype(splits.size)
118
113
  return jnp.searchsorted(splits, x).astype(dtype)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: bartz
3
- Version: 0.1.0
3
+ Version: 0.2.1
4
4
  Summary: A JAX implementation of BART
5
5
  Home-page: https://github.com/Gattocrucco/bartz
6
6
  License: MIT
@@ -20,7 +20,13 @@ Project-URL: Bug Tracker, https://github.com/Gattocrucco/bartz/issues
20
20
  Project-URL: Repository, https://github.com/Gattocrucco/bartz
21
21
  Description-Content-Type: text/markdown
22
22
 
23
+ [![PyPI](https://img.shields.io/pypi/v/bartz)](https://pypi.org/project/bartz/)
24
+
23
25
  # BART vectoriZed
24
26
 
25
27
  A branchless vectorized implementation of Bayesian Additive Regression Trees (BART) in JAX.
26
28
 
29
+ BART is a nonparametric Bayesian regression technique. Given predictors $X$ and responses $y$, BART finds a function to predict $y$ given $X$. The result of the inference is a sample of possible functions, representing the uncertainty over the determination of the function.
30
+
31
+ This Python module provides an implementation of BART that runs on GPU, to process large datasets faster. It is also good on CPU. Most other implementations of BART are for R, and run on CPU only.
32
+
@@ -0,0 +1,13 @@
1
+ bartz/BART.py,sha256=pRG7mALenknX2JHqY-VyhO9-evDgEC6hWBp4jpecBdM,15801
2
+ bartz/__init__.py,sha256=E96vsP0bZ8brejpZmEmRoXuMsUdinO_B_SKUUl1rLsg,1448
3
+ bartz/_version.py,sha256=PmcQ2PI2oP8irnLtJLJby2YfW6sBvLAmL-VpABzTqwc,22
4
+ bartz/debug.py,sha256=9ZH-JfwZVu5OPhHBEyXQHAU5H9KIu1vxLK7yNv4m4Ew,5314
5
+ bartz/grove.py,sha256=Wj_7jHl9w3uwuVdH4hoeXowimGpdRE2lGIzr4aDkzsI,8291
6
+ bartz/jaxext.py,sha256=VYA41D5F7DYcAAVtkcZtEN927HxQGOOQM-uGsgr2CPc,10996
7
+ bartz/mcmcloop.py,sha256=lheLrjVxmlyQzc_92zeNsFhdkrhEWQEjoAWFbVzknnw,7701
8
+ bartz/mcmcstep.py,sha256=6fzNMumXjMe6Fj6zoHLTf1D42JuAiQyGHfr6l1Bwrnk,39450
9
+ bartz/prepcovars.py,sha256=iiQ0WjSj4--l5DgPW626Qg2SSB6ljnaaUsBz_A8kFrI,4634
10
+ bartz-0.2.1.dist-info/LICENSE,sha256=heuIJZQK9IexJYC-fYHoLUrgj8HG8yS3G072EvKh-94,1073
11
+ bartz-0.2.1.dist-info/METADATA,sha256=eGxicC1iR-Bpjk1uKn50g6FxdFfq9S70nl7m5GmXO14,1490
12
+ bartz-0.2.1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
13
+ bartz-0.2.1.dist-info/RECORD,,
@@ -1,13 +0,0 @@
1
- bartz/__init__.py,sha256=40tX5XHoTiGnZcoeogVpyNOM_5rbHt-Y6zTI0NS7OA4,1345
2
- bartz/_version.py,sha256=IMjkMO3twhQzluVTo8Z6rE7Eg-9U79_LGKMcsWLKBkY,22
3
- bartz/debug.py,sha256=_HOjDieipAgliP6B6C0UMgz-mVgmeZ3zmtzVe-iMGtY,5289
4
- bartz/grove.py,sha256=LHhnvNKLb-jxUf4YjP927Hf9txkXynhMZ2ejtMRWZl4,8353
5
- bartz/interface.py,sha256=INyNuHzFySwXAsXofVZDpTsMv78AR_3VCvAHbZFh92c,15724
6
- bartz/jaxext.py,sha256=FK5j1zfW1yR4-yPKcD7ZvKSkVQ5--jHjQpVCl4n4gXY,2844
7
- bartz/mcmcloop.py,sha256=xTxC1AkNX8jCrMArblvlMjnjMh80q1M3a6ZGrDdfsFI,7423
8
- bartz/mcmcstep.py,sha256=6zkpTqgIrapeVy9mhy6BlsIO0s26HwBRDfw_6dVMmZA,35207
9
- bartz/prepcovars.py,sha256=3ddDOtNNop3Ba2Kgy_dZ6apFydtwaEXH3uXSmmKf9Fs,4421
10
- bartz-0.1.0.dist-info/LICENSE,sha256=heuIJZQK9IexJYC-fYHoLUrgj8HG8yS3G072EvKh-94,1073
11
- bartz-0.1.0.dist-info/METADATA,sha256=8YYlbCf7frDtT2of6tNlnBbuGqyO8YyYlED8OXSiBpA,933
12
- bartz-0.1.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
13
- bartz-0.1.0.dist-info/RECORD,,
File without changes
File without changes