bartz 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,4 @@
1
- # bartz/src/bartz/interface.py
1
+ # bartz/src/bartz/BART.py
2
2
  #
3
3
  # Copyright (c) 2024, Giacomo Petrillo
4
4
  #
@@ -33,7 +33,7 @@ from . import mcmcstep
33
33
  from . import mcmcloop
34
34
  from . import prepcovars
35
35
 
36
- class BART:
36
+ class gbart:
37
37
  """
38
38
  Nonparametric regression with Bayesian Additive Regression Trees (BART).
39
39
 
@@ -133,7 +133,7 @@ class BART:
133
133
 
134
134
  Notes
135
135
  -----
136
- This interface imitates the function `wbart` from the R package `BART
136
+ This interface imitates the function `gbart` from the R package `BART
137
137
  <https://cran.r-project.org/package=BART>`_, but with these differences:
138
138
 
139
139
  - If `x_train` and `x_test` are matrices, they have one predictor per row
@@ -142,6 +142,7 @@ class BART:
142
142
  - `usequants` is always `True`.
143
143
  - `rm_const` is always `False`.
144
144
  - The default `numcut` is 255 instead of 100.
145
+ - A lot of functionality is missing (variable selection, discrete response).
145
146
  - There are some additional attributes, and some missing.
146
147
  """
147
148
 
bartz/__init__.py CHANGED
@@ -30,6 +30,11 @@ See the manual at https://gattocrucco.github.io/bartz/docs
30
30
 
31
31
  from ._version import __version__
32
32
 
33
- from .interface import BART
33
+ from . import BART
34
34
 
35
35
  from . import debug
36
+ from . import grove
37
+ from . import mcmcstep
38
+ from . import mcmcloop
39
+ from . import prepcovars
40
+ from . import jaxext
bartz/_version.py CHANGED
@@ -1 +1 @@
1
- __version__ = '0.1.0'
1
+ __version__ = '0.2.0'
bartz/debug.py CHANGED
@@ -6,6 +6,7 @@ from jax import lax
6
6
 
7
7
  from . import grove
8
8
  from . import mcmcstep
9
+ from . import jaxext
9
10
 
10
11
  def print_tree(leaf_tree, var_tree, split_tree, print_all=False):
11
12
 
@@ -83,7 +84,7 @@ def trace_depth_distr(split_trees_trace):
83
84
  def points_per_leaf_distr(var_tree, split_tree, X):
84
85
  traverse_tree = jax.vmap(grove.traverse_tree, in_axes=(1, None, None))
85
86
  indices = traverse_tree(X, var_tree, split_tree)
86
- count_tree = jnp.zeros(2 * split_tree.size, dtype=grove.minimal_unsigned_dtype(indices.size))
87
+ count_tree = jnp.zeros(2 * split_tree.size, dtype=jaxext.minimal_unsigned_dtype(indices.size))
87
88
  count_tree = count_tree.at[indices].add(1)
88
89
  is_leaf = grove.is_actual_leaf(split_tree, add_bottom_level=True).view(jnp.uint8)
89
90
  return jnp.bincount(count_tree, is_leaf, length=X.shape[1] + 1)
@@ -103,7 +104,7 @@ def trace_points_per_leaf_distr(bart, X):
103
104
  return distr
104
105
 
105
106
  def check_types(leaf_tree, var_tree, split_tree, max_split):
106
- expected_var_dtype = grove.minimal_unsigned_dtype(max_split.size - 1)
107
+ expected_var_dtype = jaxext.minimal_unsigned_dtype(max_split.size - 1)
107
108
  expected_split_dtype = max_split.dtype
108
109
  return var_tree.dtype == expected_var_dtype and split_tree.dtype == expected_split_dtype
109
110
 
@@ -117,7 +118,7 @@ def check_leaf_values(leaf_tree, var_tree, split_tree, max_split):
117
118
  return jnp.all(jnp.isfinite(leaf_tree))
118
119
 
119
120
  def check_stray_nodes(leaf_tree, var_tree, split_tree, max_split):
120
- index = jnp.arange(2 * split_tree.size, dtype=grove.minimal_unsigned_dtype(2 * split_tree.size - 1))
121
+ index = jnp.arange(2 * split_tree.size, dtype=jaxext.minimal_unsigned_dtype(2 * split_tree.size - 1))
121
122
  parent_index = index >> 1
122
123
  is_not_leaf = split_tree.at[index].get(mode='fill', fill_value=0) != 0
123
124
  parent_is_leaf = split_tree[parent_index] == 0
@@ -134,7 +135,7 @@ check_functions = [
134
135
  ]
135
136
 
136
137
  def check_tree(leaf_tree, var_tree, split_tree, max_split):
137
- error_type = grove.minimal_unsigned_dtype(2 ** len(check_functions) - 1)
138
+ error_type = jaxext.minimal_unsigned_dtype(2 ** len(check_functions) - 1)
138
139
  error = error_type(0)
139
140
  for i, func in enumerate(check_functions):
140
141
  ok = func(leaf_tree, var_tree, split_tree, max_split)
bartz/grove.py CHANGED
@@ -44,7 +44,6 @@ import functools
44
44
  import math
45
45
 
46
46
  import jax
47
-
48
47
  from jax import numpy as jnp
49
48
  from jax import lax
50
49
 
@@ -107,29 +106,47 @@ def traverse_tree(x, var_tree, split_tree):
107
106
 
108
107
  carry = (
109
108
  jnp.zeros((), bool),
110
- jnp.ones((), minimal_unsigned_dtype(2 * var_tree.size - 1)),
109
+ jnp.ones((), jaxext.minimal_unsigned_dtype(2 * var_tree.size - 1)),
111
110
  )
112
111
 
113
112
  def loop(carry, _):
114
113
  leaf_found, index = carry
115
114
 
116
- split = split_tree.at[index].get(mode='fill', fill_value=0)
117
- var = var_tree.at[index].get(mode='fill', fill_value=0)
115
+ split = split_tree[index]
116
+ var = var_tree[index]
118
117
 
119
- leaf_found |= split_tree.at[index].get(mode='fill', fill_value=0) == 0
118
+ leaf_found |= split == 0
120
119
  child_index = (index << 1) + (x[var] >= split)
121
120
  index = jnp.where(leaf_found, index, child_index)
122
121
 
123
122
  return (leaf_found, index), None
124
123
 
125
- # TODO
126
- # - unroll (how much? 5?)
127
- # - separate and special-case the last iteration
128
-
129
- depth = 1 + tree_depth(var_tree)
130
- (_, index), _ = lax.scan(loop, carry, None, depth)
124
+ depth = tree_depth(var_tree)
125
+ (_, index), _ = lax.scan(loop, carry, None, depth, unroll=16)
131
126
  return index
132
127
 
128
+ @functools.partial(jaxext.vmap_nodoc, in_axes=(None, 0, 0))
129
+ @functools.partial(jaxext.vmap_nodoc, in_axes=(1, None, None))
130
+ def traverse_forest(X, var_trees, split_trees):
131
+ """
132
+ Find the leaves where points fall into.
133
+
134
+ Parameters
135
+ ----------
136
+ X : array (p, n)
137
+ The coordinates to evaluate the trees at.
138
+ var_trees : array (m, 2 ** (d - 1))
139
+ The decision axes of the trees.
140
+ split_trees : array (m, 2 ** (d - 1))
141
+ The decision boundaries of the trees.
142
+
143
+ Returns
144
+ -------
145
+ indices : array (m, n)
146
+ The indices of the leaves.
147
+ """
148
+ return traverse_tree(X, var_trees, split_trees)
149
+
133
150
  def evaluate_forest(X, leaf_trees, var_trees, split_trees, dtype):
134
151
  """
135
152
  Evaluate a ensemble of trees at an array of points.
@@ -138,7 +155,7 @@ def evaluate_forest(X, leaf_trees, var_trees, split_trees, dtype):
138
155
  ----------
139
156
  X : array (p, n)
140
157
  The coordinates to evaluate the trees at.
141
- leaf_trees : (m, 2 ** d)
158
+ leaf_trees : array (m, 2 ** d)
142
159
  The leaf values of the tree or forest. If the input is a forest, the
143
160
  first axis is the tree index, and the values are summed.
144
161
  var_trees : array (m, 2 ** (d - 1))
@@ -153,30 +170,13 @@ def evaluate_forest(X, leaf_trees, var_trees, split_trees, dtype):
153
170
  out : array (n,)
154
171
  The sum of the values of the trees at the points in `X`.
155
172
  """
156
- indices = _traverse_forest(X, var_trees, split_trees)
173
+ indices = traverse_forest(X, var_trees, split_trees)
157
174
  ntree, _ = leaf_trees.shape
158
- tree_index = jnp.arange(ntree, dtype=minimal_unsigned_dtype(ntree - 1))[:, None]
175
+ tree_index = jnp.arange(ntree, dtype=jaxext.minimal_unsigned_dtype(ntree - 1))[:, None]
159
176
  leaves = leaf_trees[tree_index, indices]
160
177
  return jnp.sum(leaves, axis=0, dtype=dtype)
161
- # this sum suggests to swap the vmaps, but I think it's better for X copying to keep it that way
162
-
163
- @functools.partial(jax.vmap, in_axes=(None, 0, 0))
164
- @functools.partial(jax.vmap, in_axes=(1, None, None))
165
- def _traverse_forest(X, var_trees, split_trees):
166
- return traverse_tree(X, var_trees, split_trees)
167
-
168
- def minimal_unsigned_dtype(max_value):
169
- """
170
- Return the smallest unsigned integer dtype that can represent a given
171
- maximum value.
172
- """
173
- if max_value < 2 ** 8:
174
- return jnp.uint8
175
- if max_value < 2 ** 16:
176
- return jnp.uint16
177
- if max_value < 2 ** 32:
178
- return jnp.uint32
179
- return jnp.uint64
178
+ # this sum suggests to swap the vmaps, but I think it's better for X
179
+ # copying to keep it that way
180
180
 
181
181
  def is_actual_leaf(split_tree, *, add_bottom_level=False):
182
182
  """
@@ -200,7 +200,7 @@ def is_actual_leaf(split_tree, *, add_bottom_level=False):
200
200
  if add_bottom_level:
201
201
  size *= 2
202
202
  is_leaf = jnp.concatenate([is_leaf, jnp.ones_like(is_leaf)])
203
- index = jnp.arange(size, dtype=minimal_unsigned_dtype(size - 1))
203
+ index = jnp.arange(size, dtype=jaxext.minimal_unsigned_dtype(size - 1))
204
204
  parent_index = index >> 1
205
205
  parent_nonleaf = split_tree[parent_index].astype(bool)
206
206
  parent_nonleaf = parent_nonleaf.at[1].set(True)
@@ -220,7 +220,7 @@ def is_leaves_parent(split_tree):
220
220
  is_leaves_parent : bool array (2 ** (d - 1),)
221
221
  The mask indicating which nodes have leaf children.
222
222
  """
223
- index = jnp.arange(split_tree.size, dtype=minimal_unsigned_dtype(2 * split_tree.size - 1))
223
+ index = jnp.arange(split_tree.size, dtype=jaxext.minimal_unsigned_dtype(2 * split_tree.size - 1))
224
224
  left_index = index << 1 # left child
225
225
  right_index = left_index + 1 # right child
226
226
  left_leaf = split_tree.at[left_index].get(mode='fill', fill_value=0) == 0
@@ -252,4 +252,4 @@ def tree_depths(tree_length):
252
252
  depth += 1
253
253
  depths.append(depth - 1)
254
254
  depths[0] = 0
255
- return jnp.array(depths, minimal_unsigned_dtype(max(depths)))
255
+ return jnp.array(depths, jaxext.minimal_unsigned_dtype(max(depths)))
bartz/jaxext.py CHANGED
@@ -10,10 +10,10 @@
10
10
  # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11
11
  # copies of the Software, and to permit persons to whom the Software is
12
12
  # furnished to do so, subject to the following conditions:
13
- #
13
+ #
14
14
  # The above copyright notice and this permission notice shall be included in all
15
15
  # copies or substantial portions of the Software.
16
- #
16
+ #
17
17
  # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18
18
  # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19
19
  # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
@@ -23,12 +23,19 @@
23
23
  # SOFTWARE.
24
24
 
25
25
  import functools
26
+ import math
27
+ import warnings
26
28
 
27
29
  from scipy import special
28
30
  import jax
29
31
  from jax import numpy as jnp
32
+ from jax import tree_util
33
+ from jax import lax
30
34
 
31
35
  def float_type(*args):
36
+ """
37
+ Determine the jax floating point result type given operands/types.
38
+ """
32
39
  t = jnp.result_type(*args)
33
40
  return jnp.sin(jnp.empty(0, t)).dtype
34
41
 
@@ -39,8 +46,8 @@ def castto(func, type):
39
46
  return newfunc
40
47
 
41
48
  def pure_callback_ufunc(callback, dtype, *args, excluded=None, **kwargs):
42
- """ version of jax.pure_callback that deals correctly with ufuncs,
43
- see https://github.com/google/jax/issues/17187 """
49
+ """ version of `jax.pure_callback` that deals correctly with ufuncs,
50
+ see `<https://github.com/google/jax/issues/17187>`_ """
44
51
  if excluded is None:
45
52
  excluded = ()
46
53
  shape = jnp.broadcast_shapes(*(
@@ -63,6 +70,7 @@ class scipy:
63
70
 
64
71
  class special:
65
72
 
73
+ @functools.wraps(special.gammainccinv)
66
74
  def gammainccinv(a, y):
67
75
  a = jnp.asarray(a)
68
76
  y = jnp.asarray(y)
@@ -73,13 +81,261 @@ class scipy:
73
81
  class stats:
74
82
 
75
83
  class invgamma:
76
-
84
+
77
85
  def ppf(q, a):
78
86
  return 1 / scipy.special.gammainccinv(a, q)
79
87
 
80
88
  @functools.wraps(jax.vmap)
81
89
  def vmap_nodoc(fun, *args, **kw):
90
+ """
91
+ Version of `jax.vmap` that preserves the docstring of the input function.
92
+ """
82
93
  doc = fun.__doc__
83
94
  fun = jax.vmap(fun, *args, **kw)
84
95
  fun.__doc__ = doc
85
96
  return fun
97
+
98
+ def huge_value(x):
99
+ """
100
+ Return the maximum value that can be stored in `x`.
101
+
102
+ Parameters
103
+ ----------
104
+ x : array
105
+ A numerical numpy or jax array.
106
+
107
+ Returns
108
+ -------
109
+ maxval : scalar
110
+ The maximum value allowed by `x`'s type (+inf for floats).
111
+ """
112
+ if jnp.issubdtype(x.dtype, jnp.integer):
113
+ return jnp.iinfo(x.dtype).max
114
+ else:
115
+ return jnp.inf
116
+
117
+ def minimal_unsigned_dtype(max_value):
118
+ """
119
+ Return the smallest unsigned integer dtype that can represent a given
120
+ maximum value (inclusive).
121
+ """
122
+ if max_value < 2 ** 8:
123
+ return jnp.uint8
124
+ if max_value < 2 ** 16:
125
+ return jnp.uint16
126
+ if max_value < 2 ** 32:
127
+ return jnp.uint32
128
+ return jnp.uint64
129
+
130
+ def signed_to_unsigned(int_dtype):
131
+ """
132
+ Map a signed integer type to its unsigned counterpart. Unsigned types are
133
+ passed through.
134
+ """
135
+ assert jnp.issubdtype(int_dtype, jnp.integer)
136
+ if jnp.issubdtype(int_dtype, jnp.unsignedinteger):
137
+ return int_dtype
138
+ if int_dtype == jnp.int8:
139
+ return jnp.uint8
140
+ if int_dtype == jnp.int16:
141
+ return jnp.uint16
142
+ if int_dtype == jnp.int32:
143
+ return jnp.uint32
144
+ if int_dtype == jnp.int64:
145
+ return jnp.uint64
146
+
147
+ def ensure_unsigned(x):
148
+ """
149
+ If x has signed integer type, cast it to the unsigned dtype of the same size.
150
+ """
151
+ return x.astype(signed_to_unsigned(x.dtype))
152
+
153
+ @functools.partial(jax.jit, static_argnums=(1,))
154
+ def unique(x, size, fill_value):
155
+ """
156
+ Restricted version of `jax.numpy.unique` that uses less memory.
157
+
158
+ Parameters
159
+ ----------
160
+ x : 1d array
161
+ The input array.
162
+ size : int
163
+ The length of the output.
164
+ fill_value : scalar
165
+ The value to fill the output with if `size` is greater than the number
166
+ of unique values in `x`.
167
+
168
+ Returns
169
+ -------
170
+ out : array (size,)
171
+ The unique values in `x`, sorted, and right-padded with `fill_value`.
172
+ actual_length : int
173
+ The number of used values in `out`.
174
+ """
175
+ if x.size == 0:
176
+ return jnp.full(size, fill_value, x.dtype), 0
177
+ if size == 0:
178
+ return jnp.empty(0, x.dtype), 0
179
+ x = jnp.sort(x)
180
+ def loop(carry, x):
181
+ i_out, i_in, last, out = carry
182
+ i_out = jnp.where(x == last, i_out, i_out + 1)
183
+ out = out.at[i_out].set(x)
184
+ return (i_out, i_in + 1, x, out), None
185
+ carry = 0, 0, x[0], jnp.full(size, fill_value, x.dtype)
186
+ (actual_length, _, _, out), _ = jax.lax.scan(loop, carry, x[:size])
187
+ return out, actual_length + 1
188
+
189
+ def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False):
190
+ """
191
+ Batch a function such that each batch is smaller than a threshold.
192
+
193
+ Parameters
194
+ ----------
195
+ func : callable
196
+ A jittable function with positional arguments only, with inputs and
197
+ outputs pytrees of arrays.
198
+ max_io_nbytes : int
199
+ The maximum number of input + output bytes in each batch.
200
+ in_axes : pytree of ints, default 0
201
+ A tree matching the structure of the function input, indicating along
202
+ which axes each array should be batched. If a single integer, it is
203
+ used for all arrays.
204
+ out_axes : pytree of ints, default 0
205
+ The same for outputs.
206
+ return_nbatches : bool, default False
207
+ If True, the number of batches is returned as a second output.
208
+
209
+ Returns
210
+ -------
211
+ batched_func : callable
212
+ A function with the same signature as `func`, but that processes the
213
+ input and output in batches in a loop.
214
+ """
215
+
216
+ def expand_axes(axes, tree):
217
+ if isinstance(axes, int):
218
+ return tree_util.tree_map(lambda _: axes, tree)
219
+ return tree_util.tree_map(lambda _, axis: axis, tree, axes)
220
+
221
+ def extract_size(axes, tree):
222
+ sizes = tree_util.tree_map(lambda x, axis: x.shape[axis], tree, axes)
223
+ sizes, _ = tree_util.tree_flatten(sizes)
224
+ assert all(s == sizes[0] for s in sizes)
225
+ return sizes[0]
226
+
227
+ def sum_nbytes(tree):
228
+ def nbytes(x):
229
+ return math.prod(x.shape) * x.dtype.itemsize
230
+ return tree_util.tree_reduce(lambda size, x: size + nbytes(x), tree, 0)
231
+
232
+ def next_divisor_small(dividend, min_divisor):
233
+ for divisor in range(min_divisor, int(math.sqrt(dividend)) + 1):
234
+ if dividend % divisor == 0:
235
+ return divisor
236
+ return dividend
237
+
238
+ def next_divisor_large(dividend, min_divisor):
239
+ max_inv_divisor = dividend // min_divisor
240
+ for inv_divisor in range(max_inv_divisor, 0, -1):
241
+ if dividend % inv_divisor == 0:
242
+ return dividend // inv_divisor
243
+ return dividend
244
+
245
+ def next_divisor(dividend, min_divisor):
246
+ if min_divisor * min_divisor <= dividend:
247
+ return next_divisor_small(dividend, min_divisor)
248
+ return next_divisor_large(dividend, min_divisor)
249
+
250
+ def move_axes_out(axes, tree):
251
+ def move_axis_out(axis, x):
252
+ if axis != 0:
253
+ return jnp.moveaxis(x, axis, 0)
254
+ return x
255
+ return tree_util.tree_map(move_axis_out, axes, tree)
256
+
257
+ def move_axes_in(axes, tree):
258
+ def move_axis_in(axis, x):
259
+ if axis != 0:
260
+ return jnp.moveaxis(x, 0, axis)
261
+ return x
262
+ return tree_util.tree_map(move_axis_in, axes, tree)
263
+
264
+ def batch(tree, nbatches):
265
+ def batch(x):
266
+ return x.reshape((nbatches, x.shape[0] // nbatches) + x.shape[1:])
267
+ return tree_util.tree_map(batch, tree)
268
+
269
+ def unbatch(tree):
270
+ def unbatch(x):
271
+ return x.reshape((x.shape[0] * x.shape[1],) + x.shape[2:])
272
+ return tree_util.tree_map(unbatch, tree)
273
+
274
+ def check_same(tree1, tree2):
275
+ def check_same(x1, x2):
276
+ assert x1.shape == x2.shape
277
+ assert x1.dtype == x2.dtype
278
+ tree_util.tree_map(check_same, tree1, tree2)
279
+
280
+ initial_in_axes = in_axes
281
+ initial_out_axes = out_axes
282
+
283
+ @jax.jit
284
+ @functools.wraps(func)
285
+ def batched_func(*args):
286
+ example_result = jax.eval_shape(func, *args)
287
+
288
+ in_axes = expand_axes(initial_in_axes, args)
289
+ out_axes = expand_axes(initial_out_axes, example_result)
290
+
291
+ in_size = extract_size(in_axes, args)
292
+ out_size = extract_size(out_axes, example_result)
293
+ assert in_size == out_size
294
+ size = in_size
295
+
296
+ total_nbytes = sum_nbytes(args) + sum_nbytes(example_result)
297
+ min_nbatches = total_nbytes // max_io_nbytes + bool(total_nbytes % max_io_nbytes)
298
+ nbatches = next_divisor(size, min_nbatches)
299
+ assert 1 <= nbatches <= size
300
+ assert size % nbatches == 0
301
+ assert total_nbytes % nbatches == 0
302
+
303
+ batch_nbytes = total_nbytes // nbatches
304
+ if batch_nbytes > max_io_nbytes:
305
+ assert size == nbatches
306
+ warnings.warn(f'batch_nbytes = {batch_nbytes} > max_io_nbytes = {max_io_nbytes}')
307
+
308
+ def loop(_, args):
309
+ args = move_axes_in(in_axes, args)
310
+ result = func(*args)
311
+ result = move_axes_out(out_axes, result)
312
+ return None, result
313
+
314
+ args = move_axes_out(in_axes, args)
315
+ args = batch(args, nbatches)
316
+ _, result = lax.scan(loop, None, args)
317
+ result = unbatch(result)
318
+ result = move_axes_in(out_axes, result)
319
+
320
+ check_same(example_result, result)
321
+
322
+ if return_nbatches:
323
+ return result, nbatches
324
+ return result
325
+
326
+ return batched_func
327
+
328
+ @tree_util.register_pytree_node_class
329
+ class LeafDict(dict):
330
+ """ dictionary that acts as a leaf in jax pytrees, to store compile-time
331
+ values """
332
+
333
+ def tree_flatten(self):
334
+ return (), self
335
+
336
+ @classmethod
337
+ def tree_unflatten(cls, aux_data, children):
338
+ return aux_data
339
+
340
+ def __repr__(self):
341
+ return f'{__class__.__name__}({super().__repr__()})'
bartz/mcmcloop.py CHANGED
@@ -52,7 +52,7 @@ def run_mcmc(bart, n_burn, n_save, n_skip, callback, key):
52
52
  n_save : int
53
53
  The number of iterations to save.
54
54
  n_skip : int
55
- The number of iterations to skip between each saved iteration.
55
+ The number of iterations to skip between each saved iteration, plus 1.
56
56
  callback : callable
57
57
  An arbitrary function run at each iteration, called with the following
58
58
  arguments, passed by keyword:
@@ -105,16 +105,19 @@ def run_mcmc(bart, n_burn, n_save, n_skip, callback, key):
105
105
  output = {key: bart[key] for key in tracelist}
106
106
  return (bart, i_total + 1, i_skip + 1, key), output
107
107
 
108
+ def empty_trace(bart, tracelist):
109
+ return {
110
+ key: jnp.empty((0,) + bart[key].shape, bart[key].dtype)
111
+ for key in tracelist
112
+ }
113
+
108
114
  if n_burn > 0:
109
115
  carry = bart, 0, 0, key
110
116
  burnin_loop = functools.partial(inner_loop, tracelist=tracelist_burnin, burnin=True)
111
117
  (bart, i_total, _, key), burnin_trace = lax.scan(burnin_loop, carry, None, n_burn)
112
118
  else:
113
119
  i_total = 0
114
- burnin_trace = {
115
- key: jnp.empty((0,) + bart[key].shape, bart[key].dtype)
116
- for key in tracelist_burnin
117
- }
120
+ burnin_trace = empty_trace(bart, tracelist_burnin)
118
121
 
119
122
  def outer_loop(carry, _):
120
123
  bart, i_total, key = carry
@@ -124,8 +127,11 @@ def run_mcmc(bart, n_burn, n_save, n_skip, callback, key):
124
127
  output = {key: bart[key] for key in tracelist_main}
125
128
  return (bart, i_total, key), output
126
129
 
127
- carry = bart, i_total, key
128
- (bart, _, _), main_trace = lax.scan(outer_loop, carry, None, n_save)
130
+ if n_save > 0:
131
+ carry = bart, i_total, key
132
+ (bart, _, _), main_trace = lax.scan(outer_loop, carry, None, n_save)
133
+ else:
134
+ main_trace = empty_trace(bart, tracelist_main)
129
135
 
130
136
  return bart, burnin_trace, main_trace
131
137
 
@@ -133,7 +139,8 @@ def run_mcmc(bart, n_burn, n_save, n_skip, callback, key):
133
139
 
134
140
  @functools.lru_cache
135
141
  # cache to make the callback function object unique, such that the jit
136
- # of run_mcmc recognizes it
142
+ # of run_mcmc recognizes it => with the callback state, I can make
143
+ # printevery a runtime quantity
137
144
  def make_simple_print_callback(printevery):
138
145
  """
139
146
  Create a logging callback function for MCMC iterations.
@@ -155,11 +162,12 @@ def make_simple_print_callback(printevery):
155
162
  grow_acc = bart['grow_acc_count'] / bart['grow_prop_count']
156
163
  prune_acc = bart['prune_acc_count'] / bart['prune_prop_count']
157
164
  n_total = n_burn + n_save * n_skip
158
- debug.callback(simple_print_callback_impl, burnin, i_total, n_total, grow_prop, grow_acc, prune_prop, prune_acc, printevery)
165
+ printcond = (i_total + 1) % printevery == 0
166
+ debug.callback(_simple_print_callback, burnin, i_total, n_total, grow_prop, grow_acc, prune_prop, prune_acc, printcond)
159
167
  return callback
160
168
 
161
- def simple_print_callback_impl(burnin, i_total, n_total, grow_prop, grow_acc, prune_prop, prune_acc, printevery):
162
- if (i_total + 1) % printevery == 0:
169
+ def _simple_print_callback(burnin, i_total, n_total, grow_prop, grow_acc, prune_prop, prune_acc, printcond):
170
+ if printcond:
163
171
  burnin_flag = ' (burnin)' if burnin else ''
164
172
  total_str = str(n_total)
165
173
  ndigits = len(total_str)
bartz/mcmcstep.py CHANGED
@@ -34,7 +34,6 @@ range of possible values.
34
34
  """
35
35
 
36
36
  import functools
37
- import math
38
37
 
39
38
  import jax
40
39
  from jax import random
@@ -55,6 +54,7 @@ def init(*,
55
54
  small_float=jnp.float32,
56
55
  large_float=jnp.float32,
57
56
  min_points_per_leaf=None,
57
+ suffstat_batch_size='auto',
58
58
  ):
59
59
  """
60
60
  Make a BART posterior sampling MCMC initial state.
@@ -82,6 +82,9 @@ def init(*,
82
82
  The dtype for scalars, small arrays, and arrays which require accuracy.
83
83
  min_points_per_leaf : int, optional
84
84
  The minimum number of data points in a leaf node. 0 if not specified.
85
+ suffstat_batch_size : int, None, str, default 'auto'
86
+ The batch size for computing sufficient statistics. `None` for no
87
+ batching. If 'auto', pick a value based on the device of `y`.
85
88
 
86
89
  Returns
87
90
  -------
@@ -104,8 +107,9 @@ def init(*,
104
107
  The number of grow/prune proposals made during one full MCMC cycle.
105
108
  'grow_acc_count', 'prune_acc_count' : int
106
109
  The number of grow/prune moves accepted during one full MCMC cycle.
107
- 'p_nonterminal' : large_float array (d - 1,)
108
- The probability of a nonterminal node at each depth.
110
+ 'p_nonterminal' : large_float array (d,)
111
+ The probability of a nonterminal node at each depth, padded with a
112
+ zero.
109
113
  'sigma2_alpha' : large_float
110
114
  The shape parameter of the inverse gamma prior on the noise variance.
111
115
  'sigma2_beta' : large_float
@@ -121,18 +125,36 @@ def init(*,
121
125
  'affluence_trees' : bool array (num_trees, 2 ** (d - 1)) or None
122
126
  Whether a non-bottom leaf nodes contains twice `min_points_per_leaf`
123
127
  datapoints. If `min_points_per_leaf` is not specified, this is None.
128
+ 'opt' : LeafDict
129
+ A dictionary with config values:
130
+
131
+ 'suffstat_batch_size' : int or None
132
+ The batch size for computing sufficient statistics.
133
+ 'small_float' : dtype
134
+ The dtype for large arrays used in the algorithm.
135
+ 'large_float' : dtype
136
+ The dtype for scalars, small arrays, and arrays which require
137
+ accuracy.
138
+ 'require_min_points' : bool
139
+ Whether the `min_points_per_leaf` parameter is specified.
124
140
  """
125
141
 
126
142
  p_nonterminal = jnp.asarray(p_nonterminal, large_float)
127
- max_depth = p_nonterminal.size + 1
143
+ p_nonterminal = jnp.pad(p_nonterminal, (0, 1))
144
+ max_depth = p_nonterminal.size
128
145
 
129
146
  @functools.partial(jax.vmap, in_axes=None, out_axes=0, axis_size=num_trees)
130
147
  def make_forest(max_depth, dtype):
131
148
  return grove.make_tree(max_depth, dtype)
132
149
 
150
+ small_float = jnp.dtype(small_float)
151
+ large_float = jnp.dtype(large_float)
152
+ y = jnp.asarray(y, small_float)
153
+ suffstat_batch_size = _choose_suffstat_batch_size(suffstat_batch_size, y)
154
+
133
155
  bart = dict(
134
156
  leaf_trees=make_forest(max_depth, small_float),
135
- var_trees=make_forest(max_depth - 1, grove.minimal_unsigned_dtype(X.shape[0] - 1)),
157
+ var_trees=make_forest(max_depth - 1, jaxext.minimal_unsigned_dtype(X.shape[0] - 1)),
136
158
  split_trees=make_forest(max_depth - 1, max_split.dtype),
137
159
  resid=jnp.asarray(y, large_float),
138
160
  sigma2=jnp.ones((), large_float),
@@ -143,9 +165,9 @@ def init(*,
143
165
  p_nonterminal=p_nonterminal,
144
166
  sigma2_alpha=jnp.asarray(sigma2_alpha, large_float),
145
167
  sigma2_beta=jnp.asarray(sigma2_beta, large_float),
146
- max_split=max_split,
147
- y=jnp.asarray(y, small_float),
148
- X=X,
168
+ max_split=jnp.asarray(max_split),
169
+ y=y,
170
+ X=jnp.asarray(X),
149
171
  min_points_per_leaf=(
150
172
  None if min_points_per_leaf is None else
151
173
  jnp.asarray(min_points_per_leaf)
@@ -154,10 +176,32 @@ def init(*,
154
176
  None if min_points_per_leaf is None else
155
177
  make_forest(max_depth - 1, bool).at[:, 1].set(y.size >= 2 * min_points_per_leaf)
156
178
  ),
179
+ opt=jaxext.LeafDict(
180
+ suffstat_batch_size=suffstat_batch_size,
181
+ small_float=small_float,
182
+ large_float=large_float,
183
+ require_min_points=min_points_per_leaf is not None,
184
+ ),
157
185
  )
158
186
 
159
187
  return bart
160
188
 
189
+ def _choose_suffstat_batch_size(size, y):
190
+ if size == 'auto':
191
+ platform = y.devices().pop().platform
192
+ if platform == 'cpu':
193
+ return None
194
+ # maybe I should batch residuals (not counts) for numerical
195
+ # accuracy, even if it's slower
196
+ elif platform == 'gpu':
197
+ return 128 # 128 is good on A100, and V100 at high n
198
+ # 512 is good on T4, and V100 at low n
199
+ else:
200
+ raise KeyError(f'Unknown platform: {platform}')
201
+ elif size is not None:
202
+ return int(size)
203
+ return size
204
+
161
205
  def step(bart, key):
162
206
  """
163
207
  Perform one full MCMC step on a BART state.
@@ -196,11 +240,14 @@ def sample_trees(bart, key):
196
240
 
197
241
  Notes
198
242
  -----
199
- This function zeroes the proposal counters.
243
+ This function zeroes the proposal counters before using them.
200
244
  """
245
+ bart = bart.copy()
201
246
  key, subkey = random.split(key)
202
247
  grow_moves, prune_moves = sample_moves(bart, subkey)
203
- return accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, key)
248
+ bart['var_trees'] = grow_moves['var_tree']
249
+ grow_leaf_indices = grove.traverse_forest(bart['X'], grow_moves['var_tree'], grow_moves['split_tree'])
250
+ return accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, grow_leaf_indices, key)
204
251
 
205
252
  def sample_moves(bart, key):
206
253
  """
@@ -216,20 +263,7 @@ def sample_moves(bart, key):
216
263
  Returns
217
264
  -------
218
265
  grow_moves, prune_moves : dict
219
- The proposals for grow and prune moves, with these fields:
220
-
221
- 'allowed' : bool array (num_trees,)
222
- Whether the move is possible.
223
- 'node' : int array (num_trees,)
224
- The index of the leaf to grow or node to prune.
225
- 'var_tree' : int array (num_trees, 2 ** (d - 1),)
226
- The new decision axes of the tree.
227
- 'split_tree' : int array (num_trees, 2 ** (d - 1),)
228
- The new decision boundaries of the tree.
229
- 'partial_ratio' : float array (num_trees,)
230
- A factor of the Metropolis-Hastings ratio of the move. It lacks
231
- the likelihood ratio, and the probability of proposing the prune
232
- move. For the prune move, the ratio is inverted.
266
+ The proposals for grow and prune moves. See `grow_move` and `prune_move`.
233
267
  """
234
268
  key = random.split(key, bart['var_trees'].shape[0])
235
269
  return sample_moves_vmap_trees(bart['var_trees'], bart['split_trees'], bart['affluence_trees'], bart['max_split'], bart['p_nonterminal'], key)
@@ -260,7 +294,7 @@ def grow_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, ke
260
294
  Whether a leaf has enough points to be grown.
261
295
  max_split : array (p,)
262
296
  The maximum split index for each variable.
263
- p_nonterminal : array (d - 1,)
297
+ p_nonterminal : array (d,)
264
298
  The probability of a nonterminal node at each depth.
265
299
  key : jax.dtypes.prng_key array
266
300
  A jax random key.
@@ -292,16 +326,16 @@ def grow_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, ke
292
326
  var_tree = var_tree.at[leaf_to_grow].set(var.astype(var_tree.dtype))
293
327
 
294
328
  split = choose_split(var_tree, split_tree, max_split, leaf_to_grow, key2)
295
- new_split_tree = split_tree.at[leaf_to_grow].set(split.astype(split_tree.dtype))
329
+ split_tree = split_tree.at[leaf_to_grow].set(split.astype(split_tree.dtype))
296
330
 
297
- ratio = compute_partial_ratio(num_growable, num_prunable, p_nonterminal, leaf_to_grow, split_tree, new_split_tree)
331
+ ratio = compute_partial_ratio(num_growable, num_prunable, p_nonterminal, leaf_to_grow, split_tree)
298
332
 
299
333
  return dict(
300
334
  allowed=allowed,
301
335
  node=leaf_to_grow,
302
- var_tree=var_tree,
303
- split_tree=new_split_tree,
304
336
  partial_ratio=ratio,
337
+ var_tree=var_tree,
338
+ split_tree=split_tree,
305
339
  )
306
340
 
307
341
  def choose_leaf(split_tree, affluence_tree, key):
@@ -464,7 +498,7 @@ def ancestor_variables(var_tree, max_split, node_index):
464
498
  the parent. Unused spots are filled with `p`.
465
499
  """
466
500
  max_num_ancestors = grove.tree_depth(var_tree) - 1
467
- ancestor_vars = jnp.zeros(max_num_ancestors, grove.minimal_unsigned_dtype(max_split.size))
501
+ ancestor_vars = jnp.zeros(max_num_ancestors, jaxext.minimal_unsigned_dtype(max_split.size))
468
502
  carry = ancestor_vars.size - 1, node_index, ancestor_vars
469
503
  def loop(carry, _):
470
504
  i, index, ancestor_vars = carry
@@ -569,7 +603,7 @@ def choose_split(var_tree, split_tree, max_split, leaf_index, key):
569
603
  l, r = split_range(var_tree, split_tree, max_split, leaf_index, var)
570
604
  return random.randint(key, (), l, r)
571
605
 
572
- def compute_partial_ratio(num_growable, num_prunable, p_nonterminal, leaf_to_grow, initial_split_tree, new_split_tree):
606
+ def compute_partial_ratio(num_growable, num_prunable, p_nonterminal, leaf_to_grow, new_split_tree):
573
607
  """
574
608
  Compute the product of the transition and prior ratios of a grow move.
575
609
 
@@ -580,12 +614,10 @@ def compute_partial_ratio(num_growable, num_prunable, p_nonterminal, leaf_to_gro
580
614
  num_prunable : int
581
615
  The number of leaf parents that could be pruned, after converting the
582
616
  leaf to be grown to a non-terminal node.
583
- p_nonterminal : array (d - 1,)
617
+ p_nonterminal : array (d,)
584
618
  The probability of a nonterminal node at each depth.
585
619
  leaf_to_grow : int
586
620
  The index of the leaf to grow.
587
- initial_split_tree : array (2 ** (d - 1),)
588
- The splitting points of the tree, before the leaf is grown.
589
621
  new_split_tree : array (2 ** (d - 1),)
590
622
  The splitting points of the tree, after the leaf is grown.
591
623
 
@@ -600,14 +632,18 @@ def compute_partial_ratio(num_growable, num_prunable, p_nonterminal, leaf_to_gro
600
632
  # the two ratios also contain factors num_available_split *
601
633
  # num_available_var, but they cancel out
602
634
 
603
- prune_was_allowed = prune_allowed(initial_split_tree)
604
- p_grow = jnp.where(prune_was_allowed, 0.5, 1)
635
+ prune_allowed = leaf_to_grow != 1
636
+ # prune allowed <---> the initial tree is not a root
637
+ # leaf to grow is root --> the tree can only be a root
638
+ # tree is a root --> the only leaf I can grow is root
639
+
640
+ p_grow = jnp.where(prune_allowed, 0.5, 1)
605
641
 
606
642
  trans_ratio = num_growable / (p_grow * num_prunable)
607
643
 
608
- depth = grove.tree_depths(initial_split_tree.size)[leaf_to_grow]
644
+ depth = grove.tree_depths(new_split_tree.size)[leaf_to_grow]
609
645
  p_parent = p_nonterminal[depth]
610
- cp_children = 1 - p_nonterminal.at[depth + 1].get(mode='fill', fill_value=0)
646
+ cp_children = 1 - p_nonterminal[depth + 1]
611
647
  tree_ratio = cp_children * cp_children * p_parent / (1 - p_parent)
612
648
 
613
649
  return trans_ratio * tree_ratio
@@ -626,7 +662,7 @@ def prune_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, k
626
662
  Whether a leaf has enough points to be grown.
627
663
  max_split : array (p,)
628
664
  The maximum split index for each variable.
629
- p_nonterminal : array (d - 1,)
665
+ p_nonterminal : array (d,)
630
666
  The probability of a nonterminal node at each depth.
631
667
  key : jax.dtypes.prng_key array
632
668
  A jax random key.
@@ -639,28 +675,20 @@ def prune_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, k
639
675
  'allowed' : bool
640
676
  Whether the move is possible.
641
677
  'node' : int
642
- The index of the leaf to grow.
643
- 'var_tree' : array (2 ** (d - 1),)
644
- The new decision axes of the tree.
645
- 'split_tree' : array (2 ** (d - 1),)
646
- The new decision boundaries of the tree.
678
+ The index of the node to prune.
647
679
  'partial_ratio' : float
648
680
  A factor of the Metropolis-Hastings ratio of the move. It lacks
649
681
  the likelihood ratio and the probability of proposing the prune
650
682
  move. This ratio is inverted.
651
683
  """
652
684
  node_to_prune, num_prunable, num_growable = choose_leaf_parent(split_tree, affluence_tree, key)
653
- allowed = prune_allowed(split_tree)
685
+ allowed = split_tree[1].astype(bool) # allowed iff the tree is not a root
654
686
 
655
- new_split_tree = split_tree.at[node_to_prune].set(0)
656
-
657
- ratio = compute_partial_ratio(num_growable, num_prunable, p_nonterminal, node_to_prune, new_split_tree, split_tree)
687
+ ratio = compute_partial_ratio(num_growable, num_prunable, p_nonterminal, node_to_prune, split_tree)
658
688
 
659
689
  return dict(
660
690
  allowed=allowed,
661
691
  node=node_to_prune,
662
- var_tree=var_tree,
663
- split_tree=new_split_tree,
664
692
  partial_ratio=ratio, # it is inverted in accept_move_and_sample_leaves
665
693
  )
666
694
 
@@ -702,29 +730,37 @@ def choose_leaf_parent(split_tree, affluence_tree, key):
702
730
 
703
731
  return node_to_prune, num_prunable, num_growable
704
732
 
705
- def prune_allowed(split_tree):
733
+ def accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, grow_leaf_indices, key):
706
734
  """
707
- Return whether a prune move is allowed.
735
+ Accept or reject the proposed moves and sample the new leaf values.
708
736
 
709
737
  Parameters
710
738
  ----------
711
- split_tree : array (2 ** (d - 1),)
712
- The splitting points of the tree.
739
+ bart : dict
740
+ A BART mcmc state.
741
+ grow_moves : dict
742
+ The proposals for grow moves, batched over the first axis. See
743
+ `grow_move`.
744
+ prune_moves : dict
745
+ The proposals for prune moves, batched over the first axis. See
746
+ `prune_move`.
747
+ grow_leaf_indices : int array (num_trees, n)
748
+ The leaf indices of the trees proposed by the grow move.
749
+ key : jax.dtypes.prng_key array
750
+ A jax random key.
713
751
 
714
752
  Returns
715
753
  -------
716
- allowed : bool
717
- Whether a prune move is allowed.
754
+ bart : dict
755
+ The new BART mcmc state.
718
756
  """
719
- return split_tree.at[1].get(mode='fill', fill_value=0).astype(bool)
720
-
721
- def accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, key):
722
757
  bart = bart.copy()
723
758
  def loop(carry, item):
724
759
  resid = carry.pop('resid')
725
760
  resid, carry, trees = accept_move_and_sample_leaves(
726
761
  bart['X'],
727
762
  len(bart['leaf_trees']),
763
+ bart['opt']['suffstat_batch_size'],
728
764
  resid,
729
765
  bart['sigma2'],
730
766
  bart['min_points_per_leaf'],
@@ -740,11 +776,11 @@ def accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, key):
740
776
  carry['resid'] = bart['resid']
741
777
  items = (
742
778
  bart['leaf_trees'],
743
- bart['var_trees'],
744
779
  bart['split_trees'],
745
780
  bart['affluence_trees'],
746
781
  grow_moves,
747
782
  prune_moves,
783
+ grow_leaf_indices,
748
784
  random.split(key, len(bart['leaf_trees'])),
749
785
  )
750
786
  carry, trees = lax.scan(loop, carry, items)
@@ -752,11 +788,50 @@ def accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, key):
752
788
  bart.update(trees)
753
789
  return bart
754
790
 
755
- def accept_move_and_sample_leaves(X, ntree, resid, sigma2, min_points_per_leaf, counts, leaf_tree, var_tree, split_tree, affluence_tree, grow_move, prune_move, key):
756
-
757
- # compute leaf indices according to grow move tree
758
- traverse_tree = jax.vmap(grove.traverse_tree, in_axes=(1, None, None))
759
- grow_leaf_indices = traverse_tree(X, grow_move['var_tree'], grow_move['split_tree'])
791
+ def accept_move_and_sample_leaves(X, ntree, suffstat_batch_size, resid, sigma2, min_points_per_leaf, counts, leaf_tree, split_tree, affluence_tree, grow_move, prune_move, grow_leaf_indices, key):
792
+ """
793
+ Accept or reject a proposed move and sample the new leaf values.
794
+
795
+ Parameters
796
+ ----------
797
+ X : int array (p, n)
798
+ The predictors.
799
+ ntree : int
800
+ The number of trees in the forest.
801
+ suffstat_batch_size : int, None
802
+ The batch size for computing sufficient statistics.
803
+ resid : float array (n,)
804
+ The residuals (data minus forest value).
805
+ sigma2 : float
806
+ The noise variance.
807
+ min_points_per_leaf : int or None
808
+ The minimum number of data points in a leaf node.
809
+ counts : dict
810
+ The acceptance counts from the mcmc state dict.
811
+ leaf_tree : float array (2 ** d,)
812
+ The leaf values of the tree.
813
+ split_tree : int array (2 ** (d - 1),)
814
+ The decision boundaries of the tree.
815
+ affluence_tree : bool array (2 ** (d - 1),) or None
816
+ Whether a leaf has enough points to be grown.
817
+ grow_move : dict
818
+ The proposal for the grow move. See `grow_move`.
819
+ prune_move : dict
820
+ The proposal for the prune move. See `prune_move`.
821
+ grow_leaf_indices : int array (n,)
822
+ The leaf indices of the tree proposed by the grow move.
823
+ key : jax.dtypes.prng_key array
824
+ A jax random key.
825
+
826
+ Returns
827
+ -------
828
+ resid : float array (n,)
829
+ The updated residuals (data minus forest value).
830
+ counts : dict
831
+ The updated acceptance counts.
832
+ trees : dict
833
+ The updated tree arrays.
834
+ """
760
835
 
761
836
  # compute leaf indices in starting tree
762
837
  grow_node = grow_move['node']
@@ -782,10 +857,7 @@ def accept_move_and_sample_leaves(X, ntree, resid, sigma2, min_points_per_leaf,
782
857
  resid += leaf_tree[leaf_indices]
783
858
 
784
859
  # aggregate residuals and count units per leaf
785
- grow_resid_tree = jnp.zeros_like(leaf_tree, sigma2.dtype)
786
- grow_resid_tree = grow_resid_tree.at[grow_leaf_indices].add(resid)
787
- grow_count_tree = jnp.zeros_like(leaf_tree, grove.minimal_unsigned_dtype(resid.size))
788
- grow_count_tree = grow_count_tree.at[grow_leaf_indices].add(1)
860
+ grow_resid_tree, grow_count_tree = sufficient_stat(resid, grow_leaf_indices, leaf_tree.size, suffstat_batch_size)
789
861
 
790
862
  # compute aggregations in starting tree
791
863
  # I do not zero the children because garbage there does not matter
@@ -833,10 +905,10 @@ def accept_move_and_sample_leaves(X, ntree, resid, sigma2, min_points_per_leaf,
833
905
 
834
906
  # pick trees for chosen move
835
907
  trees = {}
836
- var_tree = jnp.where(do_grow, grow_move['var_tree'], var_tree)
837
908
  split_tree = jnp.where(do_grow, grow_move['split_tree'], split_tree)
838
- var_tree = jnp.where(do_prune, prune_move['var_tree'], var_tree)
839
- split_tree = jnp.where(do_prune, prune_move['split_tree'], split_tree)
909
+ # the prune var tree is equal to the initial one, because I leave garbage values behind
910
+ split_tree = split_tree.at[prune_node].set(
911
+ jnp.where(do_prune, 0, split_tree[prune_node]))
840
912
  if min_points_per_leaf is not None:
841
913
  affluence_tree = jnp.where(do_grow, grow_affluence_tree, affluence_tree)
842
914
  affluence_tree = jnp.where(do_prune, prune_affluence_tree, affluence_tree)
@@ -869,13 +941,60 @@ def accept_move_and_sample_leaves(X, ntree, resid, sigma2, min_points_per_leaf,
869
941
  # pack trees
870
942
  trees = {
871
943
  'leaf_trees': leaf_tree,
872
- 'var_trees': var_tree,
873
944
  'split_trees': split_tree,
874
945
  'affluence_trees': affluence_tree,
875
946
  }
876
947
 
877
948
  return resid, counts, trees
878
949
 
950
+ def sufficient_stat(resid, leaf_indices, tree_size, batch_size):
951
+ """
952
+ Compute the sufficient statistics for the likelihood ratio of a tree move.
953
+
954
+ Parameters
955
+ ----------
956
+ resid : float array (n,)
957
+ The residuals (data minus forest value).
958
+ leaf_indices : int array (n,)
959
+ The leaf indices of the tree (in which leaf each data point falls into).
960
+ tree_size : int
961
+ The size of the tree array (2 ** d).
962
+ batch_size : int, None
963
+ The batch size for the aggregation. Batching increases numerical
964
+ accuracy and parallelism.
965
+
966
+ Returns
967
+ -------
968
+ resid_tree : float array (2 ** d,)
969
+ The sum of the residuals at data points in each leaf.
970
+ count_tree : int array (2 ** d,)
971
+ The number of data points in each leaf.
972
+ """
973
+ if batch_size is None:
974
+ aggr_func = _aggregate_scatter
975
+ else:
976
+ aggr_func = functools.partial(_aggregate_batched, batch_size=batch_size)
977
+ resid_tree = aggr_func(resid, leaf_indices, tree_size, jnp.float32)
978
+ count_tree = aggr_func(1, leaf_indices, tree_size, jnp.uint32)
979
+ return resid_tree, count_tree
980
+
981
+ def _aggregate_scatter(values, indices, size, dtype):
982
+ return (jnp
983
+ .zeros(size, dtype)
984
+ .at[indices]
985
+ .add(values)
986
+ )
987
+
988
+ def _aggregate_batched(values, indices, size, dtype, batch_size):
989
+ nbatches = indices.size // batch_size + bool(indices.size % batch_size)
990
+ batch_indices = jnp.arange(indices.size) // batch_size
991
+ return (jnp
992
+ .zeros((nbatches, size), dtype)
993
+ .at[batch_indices, indices]
994
+ .add(values)
995
+ .sum(axis=0)
996
+ )
997
+
879
998
  def compute_p_prune_back(new_split_tree, new_affluence_tree):
880
999
  """
881
1000
  Compute the probability of proposing a prune move after doing a grow move.
bartz/prepcovars.py CHANGED
@@ -27,8 +27,10 @@ import functools
27
27
  import jax
28
28
  from jax import numpy as jnp
29
29
 
30
+ from . import jaxext
30
31
  from . import grove
31
32
 
33
+ @functools.partial(jax.jit, static_argnums=(1,))
32
34
  def quantilized_splits_from_matrix(X, max_bins):
33
35
  """
34
36
  Determine bins that make the distribution of each predictor uniform.
@@ -52,48 +54,41 @@ def quantilized_splits_from_matrix(X, max_bins):
52
54
  The number of actually used values in each row of `splits`.
53
55
  """
54
56
  out_length = min(max_bins, X.shape[1]) - 1
55
- return quantilized_splits_from_matrix_impl(X, out_length)
57
+ # return _quantilized_splits_from_matrix(X, out_length)
58
+ @functools.partial(jaxext.autobatch, max_io_nbytes=500_000_000)
59
+ def func(X):
60
+ return _quantilized_splits_from_matrix(X, out_length)
61
+ return func(X)
56
62
 
57
63
  @functools.partial(jax.vmap, in_axes=(0, None))
58
- def quantilized_splits_from_matrix_impl(x, out_length):
59
- huge = huge_value(x)
60
- u = jnp.unique(x, size=x.size, fill_value=huge)
61
- actual_length = jnp.count_nonzero(u < huge) - 1
62
- midpoints = (u[1:] + u[:-1]) / 2
64
+ def _quantilized_splits_from_matrix(x, out_length):
65
+ huge = jaxext.huge_value(x)
66
+ u, actual_length = jaxext.unique(x, size=x.size, fill_value=huge)
67
+ actual_length -= 1
68
+ if jnp.issubdtype(x.dtype, jnp.integer):
69
+ midpoints = u[:-1] + jaxext.ensure_unsigned(u[1:] - u[:-1]) // 2
70
+ indices = jnp.arange(midpoints.size, dtype=jaxext.minimal_unsigned_dtype(midpoints.size - 1))
71
+ midpoints = jnp.where(indices < actual_length, midpoints, huge)
72
+ else:
73
+ midpoints = (u[1:] + u[:-1]) / 2
63
74
  indices = jnp.linspace(-1, actual_length, out_length + 2)[1:-1]
64
- indices = jnp.around(indices).astype(grove.minimal_unsigned_dtype(midpoints.size - 1))
75
+ indices = jnp.around(indices).astype(jaxext.minimal_unsigned_dtype(midpoints.size - 1))
65
76
  # indices calculation with float rather than int to avoid potential
66
77
  # overflow with int32, and to round to nearest instead of rounding down
67
78
  decimated_midpoints = midpoints[indices]
68
79
  truncated_midpoints = midpoints[:out_length]
69
80
  splits = jnp.where(actual_length > out_length, decimated_midpoints, truncated_midpoints)
70
81
  max_split = jnp.minimum(actual_length, out_length)
71
- max_split = max_split.astype(grove.minimal_unsigned_dtype(out_length))
82
+ max_split = max_split.astype(jaxext.minimal_unsigned_dtype(out_length))
72
83
  return splits, max_split
73
84
 
74
- def huge_value(x):
75
- """
76
- Return the maximum value that can be stored in `x`.
77
-
78
- Parameters
79
- ----------
80
- x : array
81
- A numerical numpy or jax array.
82
-
83
- Returns
84
- -------
85
- maxval : scalar
86
- The maximum value allowed by `x`'s type (+inf for floats).
87
- """
88
- if jnp.issubdtype(x.dtype, jnp.integer):
89
- return jnp.iinfo(x.dtype).max
90
- else:
91
- return jnp.inf
92
-
85
+ @jax.jit
93
86
  def bin_predictors(X, splits):
94
87
  """
95
88
  Bin the predictors according to the given splits.
96
89
 
90
+ A value ``x`` is mapped to bin ``i`` iff ``splits[i - 1] < x <= splits[i]``.
91
+
97
92
  Parameters
98
93
  ----------
99
94
  X : array (p, n)
@@ -110,9 +105,9 @@ def bin_predictors(X, splits):
110
105
  A matrix with `p` predictors and `n` observations, where each predictor
111
106
  has been replaced by the index of the bin it falls into.
112
107
  """
113
- return bin_predictors_impl(X, splits)
108
+ return _bin_predictors(X, splits)
114
109
 
115
110
  @jax.vmap
116
- def bin_predictors_impl(x, splits):
117
- dtype = grove.minimal_unsigned_dtype(splits.size)
111
+ def _bin_predictors(x, splits):
112
+ dtype = jaxext.minimal_unsigned_dtype(splits.size)
118
113
  return jnp.searchsorted(splits, x).astype(dtype)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: bartz
3
- Version: 0.1.0
3
+ Version: 0.2.0
4
4
  Summary: A JAX implementation of BART
5
5
  Home-page: https://github.com/Gattocrucco/bartz
6
6
  License: MIT
@@ -20,7 +20,13 @@ Project-URL: Bug Tracker, https://github.com/Gattocrucco/bartz/issues
20
20
  Project-URL: Repository, https://github.com/Gattocrucco/bartz
21
21
  Description-Content-Type: text/markdown
22
22
 
23
+ [![PyPI](https://img.shields.io/pypi/v/bartz)](https://pypi.org/project/bartz/)
24
+
23
25
  # BART vectoriZed
24
26
 
25
27
  A branchless vectorized implementation of Bayesian Additive Regression Trees (BART) in JAX.
26
28
 
29
+ BART is a nonparametric Bayesian regression technique. Given predictors $X$ and responses $y$, BART finds a function to predict $y$ given $X$. The result of the inference is a sample of possible functions, representing the uncertainty over the determination of the function.
30
+
31
+ This Python module provides an implementation of BART that runs on GPU, to process large datasets faster. It is also a good on CPU. Most other implementations of BART are for R, and run on CPU only.
32
+
@@ -0,0 +1,13 @@
1
+ bartz/BART.py,sha256=pRG7mALenknX2JHqY-VyhO9-evDgEC6hWBp4jpecBdM,15801
2
+ bartz/__init__.py,sha256=E96vsP0bZ8brejpZmEmRoXuMsUdinO_B_SKUUl1rLsg,1448
3
+ bartz/_version.py,sha256=FVHPBGkfhbQDi_z3v0PiKJrXXqXOx0vGW_1VaqNJi7U,22
4
+ bartz/debug.py,sha256=9ZH-JfwZVu5OPhHBEyXQHAU5H9KIu1vxLK7yNv4m4Ew,5314
5
+ bartz/grove.py,sha256=Wj_7jHl9w3uwuVdH4hoeXowimGpdRE2lGIzr4aDkzsI,8291
6
+ bartz/jaxext.py,sha256=VYA41D5F7DYcAAVtkcZtEN927HxQGOOQM-uGsgr2CPc,10996
7
+ bartz/mcmcloop.py,sha256=lheLrjVxmlyQzc_92zeNsFhdkrhEWQEjoAWFbVzknnw,7701
8
+ bartz/mcmcstep.py,sha256=3ba94hXBW4UAZ11SFshnwJAgn6bpIqSZdRy_wQjEkrk,39278
9
+ bartz/prepcovars.py,sha256=iiQ0WjSj4--l5DgPW626Qg2SSB6ljnaaUsBz_A8kFrI,4634
10
+ bartz-0.2.0.dist-info/LICENSE,sha256=heuIJZQK9IexJYC-fYHoLUrgj8HG8yS3G072EvKh-94,1073
11
+ bartz-0.2.0.dist-info/METADATA,sha256=LiYjTAzgoxUM2MAuaKtf0VW-_zciTKBkTX5B7HNvUbI,1490
12
+ bartz-0.2.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
13
+ bartz-0.2.0.dist-info/RECORD,,
@@ -1,13 +0,0 @@
1
- bartz/__init__.py,sha256=40tX5XHoTiGnZcoeogVpyNOM_5rbHt-Y6zTI0NS7OA4,1345
2
- bartz/_version.py,sha256=IMjkMO3twhQzluVTo8Z6rE7Eg-9U79_LGKMcsWLKBkY,22
3
- bartz/debug.py,sha256=_HOjDieipAgliP6B6C0UMgz-mVgmeZ3zmtzVe-iMGtY,5289
4
- bartz/grove.py,sha256=LHhnvNKLb-jxUf4YjP927Hf9txkXynhMZ2ejtMRWZl4,8353
5
- bartz/interface.py,sha256=INyNuHzFySwXAsXofVZDpTsMv78AR_3VCvAHbZFh92c,15724
6
- bartz/jaxext.py,sha256=FK5j1zfW1yR4-yPKcD7ZvKSkVQ5--jHjQpVCl4n4gXY,2844
7
- bartz/mcmcloop.py,sha256=xTxC1AkNX8jCrMArblvlMjnjMh80q1M3a6ZGrDdfsFI,7423
8
- bartz/mcmcstep.py,sha256=6zkpTqgIrapeVy9mhy6BlsIO0s26HwBRDfw_6dVMmZA,35207
9
- bartz/prepcovars.py,sha256=3ddDOtNNop3Ba2Kgy_dZ6apFydtwaEXH3uXSmmKf9Fs,4421
10
- bartz-0.1.0.dist-info/LICENSE,sha256=heuIJZQK9IexJYC-fYHoLUrgj8HG8yS3G072EvKh-94,1073
11
- bartz-0.1.0.dist-info/METADATA,sha256=8YYlbCf7frDtT2of6tNlnBbuGqyO8YyYlED8OXSiBpA,933
12
- bartz-0.1.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
13
- bartz-0.1.0.dist-info/RECORD,,
File without changes
File without changes