bartz 0.0.1__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,4 @@
1
- # bartz/src/bartz/interface.py
1
+ # bartz/src/bartz/BART.py
2
2
  #
3
3
  # Copyright (c) 2024, Giacomo Petrillo
4
4
  #
@@ -33,12 +33,12 @@ from . import mcmcstep
33
33
  from . import mcmcloop
34
34
  from . import prepcovars
35
35
 
36
- class BART:
36
+ class gbart:
37
37
  """
38
38
  Nonparametric regression with Bayesian Additive Regression Trees (BART).
39
39
 
40
40
  Regress `y_train` on `x_train` with a latent mean function represented as
41
- a sum of decision trees. The inference is carried out by estimating the
41
+ a sum of decision trees. The inference is carried out by sampling the
42
42
  posterior distribution of the tree ensemble with an MCMC.
43
43
 
44
44
  Parameters
@@ -86,7 +86,7 @@ class BART:
86
86
  predictor is binned such that its distribution in `x_train` is
87
87
  approximately uniform across bins. The number of bins is at most the
88
88
  number of unique values appearing in `x_train`, or ``numcut + 1``.
89
- Before running the algorithm, the predictors are compressed to th
89
+ Before running the algorithm, the predictors are compressed to the
90
90
  smallest integer type that fits the bin indices, so `numcut` is best set
91
91
  to the maximum value of an unsigned integer type.
92
92
  ndpost : int, default 1000
@@ -133,7 +133,7 @@ class BART:
133
133
 
134
134
  Notes
135
135
  -----
136
- This interface imitates the function `wbart` from the R package `BART
136
+ This interface imitates the function `gbart` from the R package `BART
137
137
  <https://cran.r-project.org/package=BART>`_, but with these differences:
138
138
 
139
139
  - If `x_train` and `x_test` are matrices, they have one predictor per row
@@ -142,6 +142,7 @@ class BART:
142
142
  - `usequants` is always `True`.
143
143
  - `rm_const` is always `False`.
144
144
  - The default `numcut` is 255 instead of 100.
145
+ - A lot of functionality is missing (variable selection, discrete response).
145
146
  - There are some additional attributes, and some missing.
146
147
  """
147
148
 
@@ -321,7 +322,7 @@ class BART:
321
322
  p_nonterminal = base / (1 + depth).astype(float) ** power
322
323
  sigma2_alpha = sigdf / 2
323
324
  sigma2_beta = lamda * sigma2_alpha
324
- return mcmcstep.make_bart(
325
+ return mcmcstep.init(
325
326
  X=x_train,
326
327
  y=y_train,
327
328
  max_split=max_split,
@@ -354,13 +355,6 @@ class BART:
354
355
  return scale * jnp.sqrt(trace['sigma2'])
355
356
 
356
357
 
357
- def _predict_debug(self, x_test):
358
- from . import debug
359
- x_test, x_test_fmt = self._process_predictor_input(x_test)
360
- self._check_compatible_formats(x_test_fmt, self._x_train_fmt)
361
- x_test = self._bin_predictors(x_test, self._splits)
362
- return debug.trace_evaluate_trees(self._main_trace, x_test)
363
-
364
358
  def _show_tree(self, i_sample, i_tree, print_all=False):
365
359
  from . import debug
366
360
  trace = self._main_trace
@@ -385,7 +379,7 @@ class BART:
385
379
  def _compare_resid(self):
386
380
  bart = self._mcmc_state
387
381
  resid1 = bart['resid']
388
- yhat = grove.evaluate_tree_vmap_x(bart['X'], bart['leaf_trees'], bart['var_trees'], bart['split_trees'], jnp.float32)
382
+ yhat = grove.evaluate_forest(bart['X'], bart['leaf_trees'], bart['var_trees'], bart['split_trees'], jnp.float32)
389
383
  resid2 = bart['y'] - yhat
390
384
  return resid1, resid2
391
385
 
@@ -427,7 +421,5 @@ class BART:
427
421
 
428
422
  def _tree_goes_bad(self):
429
423
  bad = self._check_trees().astype(bool)
430
- bad_before = bad[:-1]
431
- bad_after = bad[1:]
432
- goes_bad = bad_after & ~bad_before
433
- return jnp.pad(goes_bad, [(1, 0), (0, 0)])
424
+ bad_before = jnp.pad(bad[:-1], [(1, 0), (0, 0)])
425
+ return bad & ~bad_before
bartz/__init__.py CHANGED
@@ -28,8 +28,13 @@ A jax implementation of BART
28
28
  See the manual at https://gattocrucco.github.io/bartz/docs
29
29
  """
30
30
 
31
- __version__ = '0.0.1'
31
+ from ._version import __version__
32
32
 
33
- from .interface import BART
33
+ from . import BART
34
34
 
35
35
  from . import debug
36
+ from . import grove
37
+ from . import mcmcstep
38
+ from . import mcmcloop
39
+ from . import prepcovars
40
+ from . import jaxext
bartz/_version.py ADDED
@@ -0,0 +1 @@
1
+ __version__ = '0.2.0'
bartz/debug.py CHANGED
@@ -6,22 +6,7 @@ from jax import lax
6
6
 
7
7
  from . import grove
8
8
  from . import mcmcstep
9
-
10
- def trace_evaluate_trees(bart, X):
11
- """
12
- Evaluate all trees, for all samples, at all x. Out axes:
13
- 0: mcmc sample
14
- 1: tree
15
- 2: X
16
- """
17
- def loop(_, bart):
18
- return None, evaluate_all_trees(X, bart['leaf_trees'], bart['var_trees'], bart['split_trees'])
19
- _, y = lax.scan(loop, None, bart)
20
- return y
21
-
22
- @functools.partial(jax.vmap, in_axes=(None, 0, 0, 0)) # vectorize over forest
23
- def evaluate_all_trees(X, leaf_trees, var_trees, split_trees):
24
- return grove.evaluate_tree_vmap_x(X, leaf_trees, var_trees, split_trees, jnp.float32)
9
+ from . import jaxext
25
10
 
26
11
  def print_tree(leaf_tree, var_tree, split_tree, print_all=False):
27
12
 
@@ -97,8 +82,10 @@ def trace_depth_distr(split_trees_trace):
97
82
  return jax.vmap(forest_depth_distr)(split_trees_trace)
98
83
 
99
84
  def points_per_leaf_distr(var_tree, split_tree, X):
100
- dummy = jnp.ones(X.shape[1], jnp.uint8)
101
- _, count_tree = mcmcstep.agg_values(X, var_tree, split_tree, dummy, dummy.dtype)
85
+ traverse_tree = jax.vmap(grove.traverse_tree, in_axes=(1, None, None))
86
+ indices = traverse_tree(X, var_tree, split_tree)
87
+ count_tree = jnp.zeros(2 * split_tree.size, dtype=jaxext.minimal_unsigned_dtype(indices.size))
88
+ count_tree = count_tree.at[indices].add(1)
102
89
  is_leaf = grove.is_actual_leaf(split_tree, add_bottom_level=True).view(jnp.uint8)
103
90
  return jnp.bincount(count_tree, is_leaf, length=X.shape[1] + 1)
104
91
 
@@ -117,7 +104,7 @@ def trace_points_per_leaf_distr(bart, X):
117
104
  return distr
118
105
 
119
106
  def check_types(leaf_tree, var_tree, split_tree, max_split):
120
- expected_var_dtype = grove.minimal_unsigned_dtype(max_split.size - 1)
107
+ expected_var_dtype = jaxext.minimal_unsigned_dtype(max_split.size - 1)
121
108
  expected_split_dtype = max_split.dtype
122
109
  return var_tree.dtype == expected_var_dtype and split_tree.dtype == expected_split_dtype
123
110
 
@@ -125,13 +112,13 @@ def check_sizes(leaf_tree, var_tree, split_tree, max_split):
125
112
  return leaf_tree.size == 2 * var_tree.size == 2 * split_tree.size
126
113
 
127
114
  def check_unused_node(leaf_tree, var_tree, split_tree, max_split):
128
- return (leaf_tree[0] == 0) & (var_tree[0] == 0) & (split_tree[0] == 0)
115
+ return (var_tree[0] == 0) & (split_tree[0] == 0)
129
116
 
130
117
  def check_leaf_values(leaf_tree, var_tree, split_tree, max_split):
131
118
  return jnp.all(jnp.isfinite(leaf_tree))
132
119
 
133
120
  def check_stray_nodes(leaf_tree, var_tree, split_tree, max_split):
134
- index = jnp.arange(2 * split_tree.size, dtype=grove.minimal_unsigned_dtype(2 * split_tree.size - 1))
121
+ index = jnp.arange(2 * split_tree.size, dtype=jaxext.minimal_unsigned_dtype(2 * split_tree.size - 1))
135
122
  parent_index = index >> 1
136
123
  is_not_leaf = split_tree.at[index].get(mode='fill', fill_value=0) != 0
137
124
  parent_is_leaf = split_tree[parent_index] == 0
@@ -148,7 +135,7 @@ check_functions = [
148
135
  ]
149
136
 
150
137
  def check_tree(leaf_tree, var_tree, split_tree, max_split):
151
- error_type = grove.minimal_unsigned_dtype(2 ** len(check_functions) - 1)
138
+ error_type = jaxext.minimal_unsigned_dtype(2 ** len(check_functions) - 1)
152
139
  error = error_type(0)
153
140
  for i, func in enumerate(check_functions):
154
141
  ok = func(leaf_tree, var_tree, split_tree, max_split)
bartz/grove.py CHANGED
@@ -28,13 +28,15 @@ Functions to create and manipulate binary trees.
28
28
 
29
29
  A tree is represented with arrays as a heap. The root node is at index 1. The children nodes of a node at index :math:`i` are at indices :math:`2i` (left child) and :math:`2i + 1` (right child). The array element at index 0 is unused.
30
30
 
31
- A decision tree is represented by tree arrays: 'leaf', 'var', and 'split'. The 'leaf' array contains the values in the leaves. The 'var' array contains the axes along which the decision nodes operate. The 'split' array contains the decision boundaries.
31
+ A decision tree is represented by tree arrays: 'leaf', 'var', and 'split'.
32
32
 
33
- Whether a node is a leaf is indicated by the corresponding 'split' element being 0.
33
+ The 'leaf' array contains the values in the leaves.
34
34
 
35
- Since the nodes at the bottom can only be leaves and not decision nodes, the 'var' and 'split' arrays have half the length of the 'leaf' array.
35
+ The 'var' array contains the axes along which the decision nodes operate.
36
+
37
+ The 'split' array contains the decision boundaries. The boundaries are open on the right, i.e., a point belongs to the left child iff x < split. Whether a node is a leaf is indicated by the corresponding 'split' element being 0.
36
38
 
37
- The unused array element at index 0 is always fixed to 0 by convention.
39
+ Since the nodes at the bottom can only be leaves and not decision nodes, the 'var' and 'split' arrays have half the length of the 'leaf' array.
38
40
 
39
41
  """
40
42
 
@@ -63,24 +65,18 @@ def make_tree(depth, dtype):
63
65
  -------
64
66
  tree : array
65
67
  An array of zeroes with shape (2 ** depth,).
66
-
67
- Notes
68
- -----
69
- The tree is represented as a heap, with the root node at index 1, and the
70
- children of the node at index i at indices 2 * i and 2 * i + 1. The element
71
- at index 0 is unused.
72
68
  """
73
69
  return jnp.zeros(2 ** depth, dtype)
74
70
 
75
71
  def tree_depth(tree):
76
72
  """
77
- Return the maximum depth of a binary tree created by `make_tree`.
73
+ Return the maximum depth of a tree.
78
74
 
79
75
  Parameters
80
76
  ----------
81
77
  tree : array
82
- A binary tree created by `make_tree`. If the array is ND, the tree
83
- structure is assumed to be along the last axis.
78
+ A tree created by `make_tree`. If the array is ND, the tree structure is
79
+ assumed to be along the last axis.
84
80
 
85
81
  Returns
86
82
  -------
@@ -89,120 +85,98 @@ def tree_depth(tree):
89
85
  """
90
86
  return int(round(math.log2(tree.shape[-1])))
91
87
 
92
- def evaluate_tree(X, leaf_trees, var_trees, split_trees, out_dtype):
88
+ def traverse_tree(x, var_tree, split_tree):
93
89
  """
94
- Evaluate a decision tree or forest.
90
+ Find the leaf where a point falls into.
95
91
 
96
92
  Parameters
97
93
  ----------
98
- X : array (p,)
94
+ x : array (p,)
99
95
  The coordinates to evaluate the tree at.
100
- leaf_trees : array (n,) or (m, n)
101
- The leaf values of the tree or forest. If the input is a forest, the
102
- first axis is the tree index, and the values are summed.
103
- var_trees : array (n,) or (m, n)
104
- The variable indices of the tree or forest. Each index is in [0, p) and
105
- indicates which value of `X` to consider.
106
- split_trees : array (n,) or (m, n)
107
- The split values of the tree or forest. Leaf nodes are indicated by the
108
- condition `split == 0`. If non-zero, the node has children, and its left
109
- children is assigned points which satisfy `x < split`.
110
- out_dtype : dtype
111
- The dtype of the output.
96
+ var_tree : array (2 ** (d - 1),)
97
+ The decision axes of the tree.
98
+ split_tree : array (2 ** (d - 1),)
99
+ The decision boundaries of the tree.
112
100
 
113
101
  Returns
114
102
  -------
115
- out : scalar
116
- The value of the tree or forest at the given point.
103
+ index : int
104
+ The index of the leaf.
117
105
  """
118
106
 
119
- is_forest = leaf_trees.ndim == 2
120
- if is_forest:
121
- m, _ = leaf_trees.shape
122
- forest_shape = m,
123
- tree_index = jnp.arange(m, dtype=minimal_unsigned_dtype(m - 1)),
124
- else:
125
- forest_shape = ()
126
- tree_index = ()
127
-
128
107
  carry = (
129
- jnp.zeros(forest_shape, bool),
130
- jnp.zeros((), out_dtype),
131
- jnp.ones(forest_shape, minimal_unsigned_dtype(leaf_trees.shape[-1] - 1))
108
+ jnp.zeros((), bool),
109
+ jnp.ones((), jaxext.minimal_unsigned_dtype(2 * var_tree.size - 1)),
132
110
  )
133
111
 
134
112
  def loop(carry, _):
135
- leaf_found, out, node_index = carry
136
-
137
- is_leaf = split_trees.at[tree_index + (node_index,)].get(mode='fill', fill_value=0) == 0
138
- leaf_value = leaf_trees[tree_index + (node_index,)]
139
- if is_forest:
140
- leaf_sum = jnp.sum(leaf_value, where=is_leaf) # TODO set dtype to large float
141
- # alternative: dot(is_leaf, leaf_value):
142
- # - maybe faster
143
- # - maybe less accurate
144
- # - fucked by nans
145
- else:
146
- leaf_sum = jnp.where(is_leaf, leaf_value, 0)
147
- out += leaf_sum
148
- leaf_found |= is_leaf
149
-
150
- split = split_trees.at[tree_index + (node_index,)].get(mode='fill', fill_value=0)
151
- var = var_trees.at[tree_index + (node_index,)].get(mode='fill', fill_value=0)
152
- x = X[var]
113
+ leaf_found, index = carry
114
+
115
+ split = split_tree[index]
116
+ var = var_tree[index]
153
117
 
154
- node_index <<= 1
155
- node_index += x >= split
156
- node_index = jnp.where(leaf_found, 0, node_index)
118
+ leaf_found |= split == 0
119
+ child_index = (index << 1) + (x[var] >= split)
120
+ index = jnp.where(leaf_found, index, child_index)
157
121
 
158
- carry = leaf_found, out, node_index
159
- return carry, _
122
+ return (leaf_found, index), None
160
123
 
161
- depth = tree_depth(leaf_trees)
162
- (_, out, _), _ = lax.scan(loop, carry, None, depth)
163
- return out
124
+ depth = tree_depth(var_tree)
125
+ (_, index), _ = lax.scan(loop, carry, None, depth, unroll=16)
126
+ return index
164
127
 
165
- def minimal_unsigned_dtype(max_value):
128
+ @functools.partial(jaxext.vmap_nodoc, in_axes=(None, 0, 0))
129
+ @functools.partial(jaxext.vmap_nodoc, in_axes=(1, None, None))
130
+ def traverse_forest(X, var_trees, split_trees):
166
131
  """
167
- Return the smallest unsigned integer dtype that can represent a given
168
- maximum value.
132
+ Find the leaves where points fall into.
133
+
134
+ Parameters
135
+ ----------
136
+ X : array (p, n)
137
+ The coordinates to evaluate the trees at.
138
+ var_trees : array (m, 2 ** (d - 1))
139
+ The decision axes of the trees.
140
+ split_trees : array (m, 2 ** (d - 1))
141
+ The decision boundaries of the trees.
142
+
143
+ Returns
144
+ -------
145
+ indices : array (m, n)
146
+ The indices of the leaves.
169
147
  """
170
- if max_value < 2 ** 8:
171
- return jnp.uint8
172
- if max_value < 2 ** 16:
173
- return jnp.uint16
174
- if max_value < 2 ** 32:
175
- return jnp.uint32
176
- return jnp.uint64
177
-
178
- @functools.partial(jaxext.vmap_nodoc, in_axes=(1, None, None, None, None), out_axes=0)
179
- def evaluate_tree_vmap_x(X, leaf_trees, var_trees, split_trees, out_dtype):
148
+ return traverse_tree(X, var_trees, split_trees)
149
+
150
+ def evaluate_forest(X, leaf_trees, var_trees, split_trees, dtype):
180
151
  """
181
- Evaluate a decision tree or forest over multiple points.
152
+ Evaluate an ensemble of trees at an array of points.
182
153
 
183
154
  Parameters
184
155
  ----------
185
156
  X : array (p, n)
186
- The points to evaluate the tree at.
187
- leaf_trees : array (n,) or (m, n)
157
+ The coordinates to evaluate the trees at.
158
+ leaf_trees : array (m, 2 ** d)
188
159
  The leaf values of the trees; the first axis is the tree index, and
189
160
  the values of the leaves reached are summed across trees.
190
- var_trees : array (n,) or (m, n)
191
- The variable indices of the tree or forest. Each index is in [0, p) and
192
- indicates which value of `X` to consider.
193
- split_trees : array (n,) or (m, n)
194
- The split values of the tree or forest. Leaf nodes are indicated by the
195
- condition `split == 0`. If non-zero, the node has children, and its left
196
- children is assigned points which satisfy `x < split`.
197
- out_dtype : dtype
161
+ var_trees : array (m, 2 ** (d - 1))
162
+ The decision axes of the trees.
163
+ split_trees : array (m, 2 ** (d - 1))
164
+ The decision boundaries of the trees.
165
+ dtype : dtype
198
166
  The dtype of the output.
199
167
 
200
168
  Returns
201
169
  -------
202
- out : (n,)
203
- The value of the tree or forest at each point.
170
+ out : array (n,)
171
+ The sum of the values of the trees at the points in `X`.
204
172
  """
205
- return evaluate_tree(X, leaf_trees, var_trees, split_trees, out_dtype)
173
+ indices = traverse_forest(X, var_trees, split_trees)
174
+ ntree, _ = leaf_trees.shape
175
+ tree_index = jnp.arange(ntree, dtype=jaxext.minimal_unsigned_dtype(ntree - 1))[:, None]
176
+ leaves = leaf_trees[tree_index, indices]
177
+ return jnp.sum(leaves, axis=0, dtype=dtype)
178
+ # this sum suggests swapping the vmaps, but I think it's better for X
179
+ # copying to keep it that way
206
180
 
207
181
  def is_actual_leaf(split_tree, *, add_bottom_level=False):
208
182
  """
@@ -226,7 +200,7 @@ def is_actual_leaf(split_tree, *, add_bottom_level=False):
226
200
  if add_bottom_level:
227
201
  size *= 2
228
202
  is_leaf = jnp.concatenate([is_leaf, jnp.ones_like(is_leaf)])
229
- index = jnp.arange(size, dtype=minimal_unsigned_dtype(size - 1))
203
+ index = jnp.arange(size, dtype=jaxext.minimal_unsigned_dtype(size - 1))
230
204
  parent_index = index >> 1
231
205
  parent_nonleaf = split_tree[parent_index].astype(bool)
232
206
  parent_nonleaf = parent_nonleaf.at[1].set(True)
@@ -239,14 +213,14 @@ def is_leaves_parent(split_tree):
239
213
  Parameters
240
214
  ----------
241
215
  split_tree : int array (2 ** (d - 1),)
242
- The splitting points of the tree.
216
+ The decision boundaries of the tree.
243
217
 
244
218
  Returns
245
219
  -------
246
220
  is_leaves_parent : bool array (2 ** (d - 1),)
247
221
  The mask indicating which nodes have leaf children.
248
222
  """
249
- index = jnp.arange(split_tree.size, dtype=minimal_unsigned_dtype(2 * split_tree.size - 1))
223
+ index = jnp.arange(split_tree.size, dtype=jaxext.minimal_unsigned_dtype(2 * split_tree.size - 1))
250
224
  left_index = index << 1 # left child
251
225
  right_index = left_index + 1 # right child
252
226
  left_leaf = split_tree.at[left_index].get(mode='fill', fill_value=0) == 0
@@ -278,25 +252,4 @@ def tree_depths(tree_length):
278
252
  depth += 1
279
253
  depths.append(depth - 1)
280
254
  depths[0] = 0
281
- return jnp.array(depths, minimal_unsigned_dtype(max(depths)))
282
-
283
- def index_depth(index, tree_length):
284
- """
285
- Return the depth of a node in a binary tree.
286
-
287
- Parameters
288
- ----------
289
- index : int
290
- The index of the node.
291
- tree_length : int
292
- The length of the tree array, i.e., 2 ** d.
293
-
294
- Returns
295
- -------
296
- depth : int
297
- The depth of the node. The root node (index 1) has depth 0. The depth is
298
- the position of the most significant non-zero bit in the index. If
299
- ``index == 0``, return -1.
300
- """
301
- depths = tree_depths(tree_length)
302
- return depths[index]
255
+ return jnp.array(depths, jaxext.minimal_unsigned_dtype(max(depths)))