bartz 0.0.1__tar.gz → 0.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {bartz-0.0.1 → bartz-0.2.0}/PKG-INFO +7 -1
- bartz-0.2.0/README.md +9 -0
- {bartz-0.0.1 → bartz-0.2.0}/pyproject.toml +4 -2
- bartz-0.0.1/src/bartz/interface.py → bartz-0.2.0/src/bartz/BART.py +10 -18
- {bartz-0.0.1 → bartz-0.2.0}/src/bartz/__init__.py +7 -2
- bartz-0.2.0/src/bartz/_version.py +1 -0
- {bartz-0.0.1 → bartz-0.2.0}/src/bartz/debug.py +9 -22
- {bartz-0.0.1 → bartz-0.2.0}/src/bartz/grove.py +73 -120
- bartz-0.2.0/src/bartz/jaxext.py +341 -0
- {bartz-0.0.1 → bartz-0.2.0}/src/bartz/mcmcloop.py +27 -13
- {bartz-0.0.1 → bartz-0.2.0}/src/bartz/mcmcstep.py +510 -439
- {bartz-0.0.1 → bartz-0.2.0}/src/bartz/prepcovars.py +25 -30
- bartz-0.0.1/README.md +0 -3
- bartz-0.0.1/src/bartz/jaxext.py +0 -85
- {bartz-0.0.1 → bartz-0.2.0}/LICENSE +0 -0
{bartz-0.0.1 → bartz-0.2.0}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: bartz
-Version: 0.0.1
+Version: 0.2.0
 Summary: A JAX implementation of BART
 Home-page: https://github.com/Gattocrucco/bartz
 License: MIT
@@ -20,7 +20,13 @@ Project-URL: Bug Tracker, https://github.com/Gattocrucco/bartz/issues
 Project-URL: Repository, https://github.com/Gattocrucco/bartz
 Description-Content-Type: text/markdown

+[](https://pypi.org/project/bartz/)
+
 # BART vectoriZed

 A branchless vectorized implementation of Bayesian Additive Regression Trees (BART) in JAX.

+BART is a nonparametric Bayesian regression technique. Given predictors $X$ and responses $y$, BART finds a function to predict $y$ given $X$. The result of the inference is a sample of possible functions, representing the uncertainty over the determination of the function.
+
+This Python module provides an implementation of BART that runs on GPU, to process large datasets faster. It is also a good on CPU. Most other implementations of BART are for R, and run on CPU only.
+
bartz-0.2.0/README.md (ADDED)

@@ -0,0 +1,9 @@
+[](https://pypi.org/project/bartz/)
+
+# BART vectoriZed
+
+A branchless vectorized implementation of Bayesian Additive Regression Trees (BART) in JAX.
+
+BART is a nonparametric Bayesian regression technique. Given predictors $X$ and responses $y$, BART finds a function to predict $y$ given $X$. The result of the inference is a sample of possible functions, representing the uncertainty over the determination of the function.
+
+This Python module provides an implementation of BART that runs on GPU, to process large datasets faster. It is also a good on CPU. Most other implementations of BART are for R, and run on CPU only.
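The README added above is the user-facing description of the release. As a quick smoke test, the version attribute used below is introduced by the `_version.py` and `__init__.py` changes shown further down in this diff:

```python
# Minimal check that the 0.2.0 release imports; bartz is published on PyPI
# (the badge at the top of the README links to the project page).
# pip install bartz
import bartz

print(bartz.__version__)  # '0.2.0', provided by the new bartz/_version.py
```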
{bartz-0.0.1 → bartz-0.2.0}/pyproject.toml

@@ -28,7 +28,7 @@ build-backend = "poetry.core.masonry.api"

 [tool.poetry]
 name = "bartz"
-version = "0.0.1"
+version = "0.2.0"
 description = "A JAX implementation of BART"
 authors = ["Giacomo Petrillo <info@giacomopetrillo.com>"]
 license = "MIT"
@@ -52,6 +52,8 @@ scipy = "^1.11.4"
 ipython = "^8.22.2"
 matplotlib = "^3.8.3"
 appnope = "^0.1.4"
+tomli = "^2.0.1"
+packaging = "^24.0"

 [tool.poetry.group.test.dependencies]
 coverage = "^7.4.3"
@@ -59,7 +61,7 @@ pytest = "^8.1.1"

 [tool.poetry.group.docs.dependencies]
 Sphinx = "^7.2.6"
-numpydoc = "^1.6.0"
+numpydoc = "^1.6.0,<1.7.0" # 1.7.0 breaks linkcode, it seems
 myst-parser = "^2.0.0"

 [tool.pytest.ini_options]
bartz-0.0.1/src/bartz/interface.py → bartz-0.2.0/src/bartz/BART.py

@@ -1,4 +1,4 @@
-# bartz/src/bartz/interface.py
+# bartz/src/bartz/BART.py
 #
 # Copyright (c) 2024, Giacomo Petrillo
 #
@@ -33,12 +33,12 @@ from . import mcmcstep
 from . import mcmcloop
 from . import prepcovars

-class BART:
+class gbart:
 """
 Nonparametric regression with Bayesian Additive Regression Trees (BART).

 Regress `y_train` on `x_train` with a latent mean function represented as
-a sum of decision trees. The inference is carried out by
+a sum of decision trees. The inference is carried out by sampling the
 posterior distribution of the tree ensemble with an MCMC.

 Parameters
@@ -86,7 +86,7 @@ class BART:
 predictor is binned such that its distribution in `x_train` is
 approximately uniform across bins. The number of bins is at most the
 number of unique values appearing in `x_train`, or ``numcut + 1``.
-Before running the algorithm, the predictors are compressed to
+Before running the algorithm, the predictors are compressed to the
 smallest integer type that fits the bin indices, so `numcut` is best set
 to the maximum value of an unsigned integer type.
 ndpost : int, default 1000
@@ -133,7 +133,7 @@ class BART:

 Notes
 -----
-This interface imitates the function `
+This interface imitates the function `gbart` from the R package `BART
 <https://cran.r-project.org/package=BART>`_, but with these differences:

 - If `x_train` and `x_test` are matrices, they have one predictor per row
@@ -142,6 +142,7 @@ class BART:
 - `usequants` is always `True`.
 - `rm_const` is always `False`.
 - The default `numcut` is 255 instead of 100.
+- A lot of functionality is missing (variable selection, discrete response).
 - There are some additional attributes, and some missing.
 """

@@ -321,7 +322,7 @@ class BART:
 p_nonterminal = base / (1 + depth).astype(float) ** power
 sigma2_alpha = sigdf / 2
 sigma2_beta = lamda * sigma2_alpha
-return mcmcstep.
+return mcmcstep.init(
 X=x_train,
 y=y_train,
 max_split=max_split,
@@ -354,13 +355,6 @@ class BART:
 return scale * jnp.sqrt(trace['sigma2'])


-def _predict_debug(self, x_test):
-from . import debug
-x_test, x_test_fmt = self._process_predictor_input(x_test)
-self._check_compatible_formats(x_test_fmt, self._x_train_fmt)
-x_test = self._bin_predictors(x_test, self._splits)
-return debug.trace_evaluate_trees(self._main_trace, x_test)
-
 def _show_tree(self, i_sample, i_tree, print_all=False):
 from . import debug
 trace = self._main_trace
@@ -385,7 +379,7 @@ class BART:
 def _compare_resid(self):
 bart = self._mcmc_state
 resid1 = bart['resid']
-yhat = grove.
+yhat = grove.evaluate_forest(bart['X'], bart['leaf_trees'], bart['var_trees'], bart['split_trees'], jnp.float32)
 resid2 = bart['y'] - yhat
 return resid1, resid2

@@ -427,7 +421,5 @@ class BART:

 def _tree_goes_bad(self):
 bad = self._check_trees().astype(bool)
-bad_before = bad[:-1]
-
-goes_bad = bad_after & ~bad_before
-return jnp.pad(goes_bad, [(1, 0), (0, 0)])
+bad_before = jnp.pad(bad[:-1], [(1, 0), (0, 0)])
+return bad & ~bad_before
{bartz-0.0.1 → bartz-0.2.0}/src/bartz/__init__.py

@@ -28,8 +28,13 @@ A jax implementation of BART
 See the manual at https://gattocrucco.github.io/bartz/docs
 """

-
+from ._version import __version__

-from .
+from . import BART

 from . import debug
+from . import grove
+from . import mcmcstep
+from . import mcmcloop
+from . import prepcovars
+from . import jaxext

bartz-0.2.0/src/bartz/_version.py (ADDED)

@@ -0,0 +1 @@
+__version__ = '0.2.0'

{bartz-0.0.1 → bartz-0.2.0}/src/bartz/debug.py

@@ -6,22 +6,7 @@ from jax import lax

 from . import grove
 from . import mcmcstep
-
-def trace_evaluate_trees(bart, X):
-"""
-Evaluate all trees, for all samples, at all x. Out axes:
-0: mcmc sample
-1: tree
-2: X
-"""
-def loop(_, bart):
-return None, evaluate_all_trees(X, bart['leaf_trees'], bart['var_trees'], bart['split_trees'])
-_, y = lax.scan(loop, None, bart)
-return y
-
-@functools.partial(jax.vmap, in_axes=(None, 0, 0, 0)) # vectorize over forest
-def evaluate_all_trees(X, leaf_trees, var_trees, split_trees):
-return grove.evaluate_tree_vmap_x(X, leaf_trees, var_trees, split_trees, jnp.float32)
+from . import jaxext

 def print_tree(leaf_tree, var_tree, split_tree, print_all=False):

@@ -97,8 +82,10 @@ def trace_depth_distr(split_trees_trace):
 return jax.vmap(forest_depth_distr)(split_trees_trace)

 def points_per_leaf_distr(var_tree, split_tree, X):
-
-
+traverse_tree = jax.vmap(grove.traverse_tree, in_axes=(1, None, None))
+indices = traverse_tree(X, var_tree, split_tree)
+count_tree = jnp.zeros(2 * split_tree.size, dtype=jaxext.minimal_unsigned_dtype(indices.size))
+count_tree = count_tree.at[indices].add(1)
 is_leaf = grove.is_actual_leaf(split_tree, add_bottom_level=True).view(jnp.uint8)
 return jnp.bincount(count_tree, is_leaf, length=X.shape[1] + 1)

@@ -117,7 +104,7 @@ def trace_points_per_leaf_distr(bart, X):
 return distr

 def check_types(leaf_tree, var_tree, split_tree, max_split):
-expected_var_dtype =
+expected_var_dtype = jaxext.minimal_unsigned_dtype(max_split.size - 1)
 expected_split_dtype = max_split.dtype
 return var_tree.dtype == expected_var_dtype and split_tree.dtype == expected_split_dtype

@@ -125,13 +112,13 @@ def check_sizes(leaf_tree, var_tree, split_tree, max_split):
 return leaf_tree.size == 2 * var_tree.size == 2 * split_tree.size

 def check_unused_node(leaf_tree, var_tree, split_tree, max_split):
-return (
+return (var_tree[0] == 0) & (split_tree[0] == 0)

 def check_leaf_values(leaf_tree, var_tree, split_tree, max_split):
 return jnp.all(jnp.isfinite(leaf_tree))

 def check_stray_nodes(leaf_tree, var_tree, split_tree, max_split):
-index = jnp.arange(2 * split_tree.size, dtype=
+index = jnp.arange(2 * split_tree.size, dtype=jaxext.minimal_unsigned_dtype(2 * split_tree.size - 1))
 parent_index = index >> 1
 is_not_leaf = split_tree.at[index].get(mode='fill', fill_value=0) != 0
 parent_is_leaf = split_tree[parent_index] == 0

@@ -148,7 +135,7 @@ check_functions = [
 ]

 def check_tree(leaf_tree, var_tree, split_tree, max_split):
-error_type =
+error_type = jaxext.minimal_unsigned_dtype(2 ** len(check_functions) - 1)
 error = error_type(0)
 for i, func in enumerate(check_functions):
 ok = func(leaf_tree, var_tree, split_tree, max_split)
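The updated debug checks above call `jaxext.minimal_unsigned_dtype(...)` with the largest value that has to be representable, where 0.0.1 picked the dtype inline. The new `jaxext.py` is not included in this section, so the following is only a plausible sketch of such a helper, consistent with its call sites and with the uint16/uint32/uint64 ladder removed from grove.py below; the real implementation may differ:

```python
# Hypothetical sketch of jaxext.minimal_unsigned_dtype (not taken from the diff):
# given the largest value to store, return the narrowest unsigned integer dtype.
import jax.numpy as jnp

def minimal_unsigned_dtype(max_value):
    if max_value < 2 ** 8:
        return jnp.uint8
    if max_value < 2 ** 16:
        return jnp.uint16
    if max_value < 2 ** 32:
        return jnp.uint32
    return jnp.uint64

# e.g. heap indices of a 256-node tree (max index 255) fit in uint8
assert minimal_unsigned_dtype(255) == jnp.uint8
assert minimal_unsigned_dtype(256) == jnp.uint16
```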
{bartz-0.0.1 → bartz-0.2.0}/src/bartz/grove.py

@@ -28,13 +28,15 @@ Functions to create and manipulate binary trees.

 A tree is represented with arrays as a heap. The root node is at index 1. The children nodes of a node at index :math:`i` are at indices :math:`2i` (left child) and :math:`2i + 1` (right child). The array element at index 0 is unused.

-A decision tree is represented by tree arrays: 'leaf', 'var', and 'split'.
+A decision tree is represented by tree arrays: 'leaf', 'var', and 'split'.

-
+The 'leaf' array contains the values in the leaves.

-
+The 'var' array contains the axes along which the decision nodes operate.
+
+The 'split' array contains the decision boundaries. The boundaries are open on the right, i.e., a point belongs to the left child iff x < split. Whether a node is a leaf is indicated by the corresponding 'split' element being 0.

-
+Since the nodes at the bottom can only be leaves and not decision nodes, the 'var' and 'split' arrays have half the length of the 'leaf' array.

 """

@@ -63,24 +65,18 @@ def make_tree(depth, dtype):
 -------
 tree : array
 An array of zeroes with shape (2 ** depth,).
-
-Notes
------
-The tree is represented as a heap, with the root node at index 1, and the
-children of the node at index i at indices 2 * i and 2 * i + 1. The element
-at index 0 is unused.
 """
 return jnp.zeros(2 ** depth, dtype)

 def tree_depth(tree):
 """
-Return the maximum depth of a
+Return the maximum depth of a tree.

 Parameters
 ----------
 tree : array
-A
-
+A tree created by `make_tree`. If the array is ND, the tree structure is
+assumed to be along the last axis.

 Returns
 -------
@@ -89,120 +85,98 @@ def tree_depth(tree):
 """
 return int(round(math.log2(tree.shape[-1])))

-def
+def traverse_tree(x, var_tree, split_tree):
 """
-
+Find the leaf where a point falls into.

 Parameters
 ----------
-
+x : array (p,)
 The coordinates to evaluate the tree at.
-
-The
-
-
-The variable indices of the tree or forest. Each index is in [0, p) and
-indicates which value of `X` to consider.
-split_trees : array (n,) or (m, n)
-The split values of the tree or forest. Leaf nodes are indicated by the
-condition `split == 0`. If non-zero, the node has children, and its left
-children is assigned points which satisfy `x < split`.
-out_dtype : dtype
-The dtype of the output.
+var_tree : array (2 ** (d - 1),)
+The decision axes of the tree.
+split_tree : array (2 ** (d - 1),)
+The decision boundaries of the tree.

 Returns
 -------
-
-The
+index : int
+The index of the leaf.
 """

-is_forest = leaf_trees.ndim == 2
-if is_forest:
-m, _ = leaf_trees.shape
-forest_shape = m,
-tree_index = jnp.arange(m, dtype=minimal_unsigned_dtype(m - 1)),
-else:
-forest_shape = ()
-tree_index = ()
-
 carry = (
-jnp.zeros(
-jnp.
-jnp.ones(forest_shape, minimal_unsigned_dtype(leaf_trees.shape[-1] - 1))
+jnp.zeros((), bool),
+jnp.ones((), jaxext.minimal_unsigned_dtype(2 * var_tree.size - 1)),
 )

 def loop(carry, _):
-leaf_found,
-
-
-
-if is_forest:
-leaf_sum = jnp.sum(leaf_value, where=is_leaf) # TODO set dtype to large float
-# alternative: dot(is_leaf, leaf_value):
-# - maybe faster
-# - maybe less accurate
-# - fucked by nans
-else:
-leaf_sum = jnp.where(is_leaf, leaf_value, 0)
-out += leaf_sum
-leaf_found |= is_leaf
-
-split = split_trees.at[tree_index + (node_index,)].get(mode='fill', fill_value=0)
-var = var_trees.at[tree_index + (node_index,)].get(mode='fill', fill_value=0)
-x = X[var]
+leaf_found, index = carry
+
+split = split_tree[index]
+var = var_tree[index]

-
-
-
+leaf_found |= split == 0
+child_index = (index << 1) + (x[var] >= split)
+index = jnp.where(leaf_found, index, child_index)

-
-return carry, _
+return (leaf_found, index), None

-depth = tree_depth(
-(_,
-return
+depth = tree_depth(var_tree)
+(_, index), _ = lax.scan(loop, carry, None, depth, unroll=16)
+return index

-
+@functools.partial(jaxext.vmap_nodoc, in_axes=(None, 0, 0))
+@functools.partial(jaxext.vmap_nodoc, in_axes=(1, None, None))
+def traverse_forest(X, var_trees, split_trees):
 """
-
-
+Find the leaves where points fall into.
+
+Parameters
+----------
+X : array (p, n)
+The coordinates to evaluate the trees at.
+var_trees : array (m, 2 ** (d - 1))
+The decision axes of the trees.
+split_trees : array (m, 2 ** (d - 1))
+The decision boundaries of the trees.
+
+Returns
+-------
+indices : array (m, n)
+The indices of the leaves.
 """
-
-
-
-return jnp.uint16
-if max_value < 2 ** 32:
-return jnp.uint32
-return jnp.uint64
-
-@functools.partial(jaxext.vmap_nodoc, in_axes=(1, None, None, None, None), out_axes=0)
-def evaluate_tree_vmap_x(X, leaf_trees, var_trees, split_trees, out_dtype):
+return traverse_tree(X, var_trees, split_trees)
+
+def evaluate_forest(X, leaf_trees, var_trees, split_trees, dtype):
 """
-Evaluate a
+Evaluate a ensemble of trees at an array of points.

 Parameters
 ----------
 X : array (p, n)
-The
-leaf_trees : array (
+The coordinates to evaluate the trees at.
+leaf_trees : array (m, 2 ** d)
 The leaf values of the tree or forest. If the input is a forest, the
 first axis is the tree index, and the values are summed.
-var_trees : array (
-The
-
-
-
-condition `split == 0`. If non-zero, the node has children, and its left
-children is assigned points which satisfy `x < split`.
-out_dtype : dtype
+var_trees : array (m, 2 ** (d - 1))
+The decision axes of the trees.
+split_trees : array (m, 2 ** (d - 1))
+The decision boundaries of the trees.
+dtype : dtype
 The dtype of the output.

 Returns
 -------
-out : (n,)
-The
+out : array (n,)
+The sum of the values of the trees at the points in `X`.
 """
-
+indices = traverse_forest(X, var_trees, split_trees)
+ntree, _ = leaf_trees.shape
+tree_index = jnp.arange(ntree, dtype=jaxext.minimal_unsigned_dtype(ntree - 1))[:, None]
+leaves = leaf_trees[tree_index, indices]
+return jnp.sum(leaves, axis=0, dtype=dtype)
+# this sum suggests to swap the vmaps, but I think it's better for X
+# copying to keep it that way

 def is_actual_leaf(split_tree, *, add_bottom_level=False):
 """
@@ -226,7 +200,7 @@ def is_actual_leaf(split_tree, *, add_bottom_level=False):
 if add_bottom_level:
 size *= 2
 is_leaf = jnp.concatenate([is_leaf, jnp.ones_like(is_leaf)])
-index = jnp.arange(size, dtype=minimal_unsigned_dtype(size - 1))
+index = jnp.arange(size, dtype=jaxext.minimal_unsigned_dtype(size - 1))
 parent_index = index >> 1
 parent_nonleaf = split_tree[parent_index].astype(bool)
 parent_nonleaf = parent_nonleaf.at[1].set(True)
@@ -239,14 +213,14 @@ def is_leaves_parent(split_tree):
 Parameters
 ----------
 split_tree : int array (2 ** (d - 1),)
-The
+The decision boundaries of the tree.

 Returns
 -------
 is_leaves_parent : bool array (2 ** (d - 1),)
 The mask indicating which nodes have leaf children.
 """
-index = jnp.arange(split_tree.size, dtype=minimal_unsigned_dtype(2 * split_tree.size - 1))
+index = jnp.arange(split_tree.size, dtype=jaxext.minimal_unsigned_dtype(2 * split_tree.size - 1))
 left_index = index << 1 # left child
 right_index = left_index + 1 # right child
 left_leaf = split_tree.at[left_index].get(mode='fill', fill_value=0) == 0
@@ -278,25 +252,4 @@ def tree_depths(tree_length):
 depth += 1
 depths.append(depth - 1)
 depths[0] = 0
-return jnp.array(depths, minimal_unsigned_dtype(max(depths)))
-
-def index_depth(index, tree_length):
-"""
-Return the depth of a node in a binary tree.
-
-Parameters
-----------
-index : int
-The index of the node.
-tree_length : int
-The length of the tree array, i.e., 2 ** d.
-
-Returns
--------
-depth : int
-The depth of the node. The root node (index 1) has depth 0. The depth is
-the position of the most significant non-zero bit in the index. If
-``index == 0``, return -1.
-"""
-depths = tree_depths(tree_length)
-return depths[index]
+return jnp.array(depths, jaxext.minimal_unsigned_dtype(max(depths)))