bartz 0.2.1__tar.gz → 0.3.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bartz-0.3.0/PKG-INFO ADDED
@@ -0,0 +1,77 @@
+ Metadata-Version: 2.1
+ Name: bartz
+ Version: 0.3.0
+ Summary: A JAX implementation of BART
+ Home-page: https://github.com/Gattocrucco/bartz
+ License: MIT
+ Author: Giacomo Petrillo
+ Author-email: info@giacomopetrillo.com
+ Requires-Python: >=3.10,<4.0
+ Classifier: License :: OSI Approved :: MIT License
+ Classifier: Programming Language :: Python :: 3
+ Classifier: Programming Language :: Python :: 3.10
+ Classifier: Programming Language :: Python :: 3.11
+ Classifier: Programming Language :: Python :: 3.12
+ Requires-Dist: jax (>=0.4.23,<0.5.0)
+ Requires-Dist: jaxlib (>=0.4.23,<0.5.0)
+ Requires-Dist: numpy (>=1.25.2,<2.0.0)
+ Requires-Dist: scipy (>=1.11.4,<2.0.0)
+ Project-URL: Bug Tracker, https://github.com/Gattocrucco/bartz/issues
+ Project-URL: Repository, https://github.com/Gattocrucco/bartz
+ Description-Content-Type: text/markdown
+
+ [![PyPI](https://img.shields.io/pypi/v/bartz)](https://pypi.org/project/bartz/)
+
+ # BART vectoriZed
+
+ A branchless vectorized implementation of Bayesian Additive Regression Trees (BART) in JAX.
+
+ BART is a nonparametric Bayesian regression technique. Given predictors $X$ and responses $y$, BART finds a function to predict $y$ given $X$. The result of the inference is a sample of possible functions, representing the uncertainty over the determination of the function.
+
+ This Python module provides an implementation of BART that runs on GPU, to process large datasets faster. It also performs well on CPU. Most other implementations of BART are for R, and run on CPU only.
+
+ On CPU, bartz runs at the speed of dbarts (the fastest implementation I know of), but using half the memory. On GPU, the speed premium depends on sample size; with 50000 datapoints and 5000 trees, on an Nvidia Tesla V100 GPU it's 12 times faster than an Apple M1 CPU, and this factor is linearly proportional to the number of datapoints.
+
+ ## Links
+
+ - [Documentation (latest release)](https://gattocrucco.github.io/bartz/docs)
+ - [Documentation (development version)](https://gattocrucco.github.io/bartz/docs-dev)
+ - [Repository](https://github.com/Gattocrucco/bartz)
+ - [Code coverage](https://gattocrucco.github.io/bartz/coverage)
+
+ ## Other BART packages
+
+ - [stochtree](https://github.com/StochasticTree) C++ library with R and Python bindings, tailored to researchers who want to make their own BART variants
+ - [bnptools](https://github.com/rsparapa/bnptools) Feature-rich R packages for BART and some variants
+ - [dbarts](https://github.com/vdorie/dbarts) Fast R package
+ - [bartMachine](https://github.com/kapelner/bartMachine) Fast R package, supports missing predictors imputation
+ - [SoftBART](https://github.com/theodds/SoftBART) R package with a smooth version of BART
+ - [bcf](https://github.com/jaredsmurray/bcf) R package for a version of BART for causal inference
+ - [flexBART](https://github.com/skdeshpande91/flexBART) Fast R package, supports categorical predictors
+ - [flexBCF](https://github.com/skdeshpande91/flexBCF) R package, version of bcf optimized for large datasets
+ - [XBART](https://github.com/JingyuHe/XBART) R/Python package, XBART is a faster variant of BART
+ - [BART](https://github.com/JingyuHe/BART) R package, BART warm-started with XBART
+ - [XBCF](https://github.com/socket778/XBCF)
+ - [BayesTree](https://cran.r-project.org/package=BayesTree) R package, original BART implementation
+ - [bartCause](https://github.com/vdorie/bartCause) R package, pre-made BART-based workflows for causal inference
+ - [stan4bart](https://github.com/vdorie/stan4bart)
+ - [VCBART](https://github.com/skdeshpande91/VCBART)
+ - [monbart](https://github.com/jaredsmurray/monbart)
+ - [mBART](https://github.com/remcc/mBART_shlib)
+ - [SequentialBART](https://github.com/mjdaniels/SequentialBART)
+ - [sparseBART](https://github.com/cspanbauer/sparseBART)
+ - [pymc-bart](https://github.com/pymc-devs/pymc-bart)
+ - [semibart](https://github.com/zeldow/semibart)
+ - [CSP-BART](https://github.com/ebprado/CSP-BART)
+ - [AMBARTI](https://github.com/ebprado/AMBARTI)
+ - [MOTR-BART](https://github.com/ebprado/MOTR-BART)
+ - [bcfbma](https://github.com/EoghanONeill/bcfbma)
+ - [bartBMAnew](https://github.com/EoghanONeill/bartBMAnew)
+ - [BART-BMA](https://github.com/BelindaHernandez/BART-BMA) (superseded by bartBMAnew)
+ - [gpbart](https://github.com/MateusMaiaDS/gpbart)
+ - [GPBART](https://github.com/nchenderson/GPBART)
+ - [bartpy](https://github.com/JakeColtman/bartpy)
+ - [BayesTreePrior](https://github.com/AlexiaJM/BayesTreePrior)
+ - [BayesTree.jl](https://github.com/mathcg/BayesTree.jl)
+ - [longbet](https://github.com/google/longbet)
+
bartz-0.3.0/README.md ADDED
@@ -0,0 +1,54 @@
+ [![PyPI](https://img.shields.io/pypi/v/bartz)](https://pypi.org/project/bartz/)
+
+ # BART vectoriZed
+
+ A branchless vectorized implementation of Bayesian Additive Regression Trees (BART) in JAX.
+
+ BART is a nonparametric Bayesian regression technique. Given predictors $X$ and responses $y$, BART finds a function to predict $y$ given $X$. The result of the inference is a sample of possible functions, representing the uncertainty over the determination of the function.
+
+ This Python module provides an implementation of BART that runs on GPU, to process large datasets faster. It also performs well on CPU. Most other implementations of BART are for R, and run on CPU only.
+
+ On CPU, bartz runs at the speed of dbarts (the fastest implementation I know of), but using half the memory. On GPU, the speed premium depends on sample size; with 50000 datapoints and 5000 trees, on an Nvidia Tesla V100 GPU it's 12 times faster than an Apple M1 CPU, and this factor is linearly proportional to the number of datapoints.
+
+ ## Links
+
+ - [Documentation (latest release)](https://gattocrucco.github.io/bartz/docs)
+ - [Documentation (development version)](https://gattocrucco.github.io/bartz/docs-dev)
+ - [Repository](https://github.com/Gattocrucco/bartz)
+ - [Code coverage](https://gattocrucco.github.io/bartz/coverage)
+
+ ## Other BART packages
+
+ - [stochtree](https://github.com/StochasticTree) C++ library with R and Python bindings, tailored to researchers who want to make their own BART variants
+ - [bnptools](https://github.com/rsparapa/bnptools) Feature-rich R packages for BART and some variants
+ - [dbarts](https://github.com/vdorie/dbarts) Fast R package
+ - [bartMachine](https://github.com/kapelner/bartMachine) Fast R package, supports missing predictors imputation
+ - [SoftBART](https://github.com/theodds/SoftBART) R package with a smooth version of BART
+ - [bcf](https://github.com/jaredsmurray/bcf) R package for a version of BART for causal inference
+ - [flexBART](https://github.com/skdeshpande91/flexBART) Fast R package, supports categorical predictors
+ - [flexBCF](https://github.com/skdeshpande91/flexBCF) R package, version of bcf optimized for large datasets
+ - [XBART](https://github.com/JingyuHe/XBART) R/Python package, XBART is a faster variant of BART
+ - [BART](https://github.com/JingyuHe/BART) R package, BART warm-started with XBART
+ - [XBCF](https://github.com/socket778/XBCF)
+ - [BayesTree](https://cran.r-project.org/package=BayesTree) R package, original BART implementation
+ - [bartCause](https://github.com/vdorie/bartCause) R package, pre-made BART-based workflows for causal inference
+ - [stan4bart](https://github.com/vdorie/stan4bart)
+ - [VCBART](https://github.com/skdeshpande91/VCBART)
+ - [monbart](https://github.com/jaredsmurray/monbart)
+ - [mBART](https://github.com/remcc/mBART_shlib)
+ - [SequentialBART](https://github.com/mjdaniels/SequentialBART)
+ - [sparseBART](https://github.com/cspanbauer/sparseBART)
+ - [pymc-bart](https://github.com/pymc-devs/pymc-bart)
+ - [semibart](https://github.com/zeldow/semibart)
+ - [CSP-BART](https://github.com/ebprado/CSP-BART)
+ - [AMBARTI](https://github.com/ebprado/AMBARTI)
+ - [MOTR-BART](https://github.com/ebprado/MOTR-BART)
+ - [bcfbma](https://github.com/EoghanONeill/bcfbma)
+ - [bartBMAnew](https://github.com/EoghanONeill/bartBMAnew)
+ - [BART-BMA](https://github.com/BelindaHernandez/BART-BMA) (superseded by bartBMAnew)
+ - [gpbart](https://github.com/MateusMaiaDS/gpbart)
+ - [GPBART](https://github.com/nchenderson/GPBART)
+ - [bartpy](https://github.com/JakeColtman/bartpy)
+ - [BayesTreePrior](https://github.com/AlexiaJM/BayesTreePrior)
+ - [BayesTree.jl](https://github.com/mathcg/BayesTree.jl)
+ - [longbet](https://github.com/google/longbet)
@@ -10,10 +10,10 @@
  # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  # copies of the Software, and to permit persons to whom the Software is
  # furnished to do so, subject to the following conditions:
- #
+ #
  # The above copyright notice and this permission notice shall be included in all
  # copies or substantial portions of the Software.
- #
+ #
  # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
@@ -28,7 +28,7 @@ build-backend = "poetry.core.masonry.api"
 
  [tool.poetry]
  name = "bartz"
- version = "0.2.1"
+ version = "0.3.0"
  description = "A JAX implementation of BART"
  authors = ["Giacomo Petrillo <info@giacomopetrillo.com>"]
  license = "MIT"
@@ -54,6 +54,8 @@ matplotlib = "^3.8.3"
  appnope = "^0.1.4"
  tomli = "^2.0.1"
  packaging = "^24.0"
+ xgboost = "^2.0.3"
+ pre-commit = "^3.7.0"
 
  [tool.poetry.group.test.dependencies]
  coverage = "^7.4.3"
@@ -10,10 +10,10 @@
  # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  # copies of the Software, and to permit persons to whom the Software is
  # furnished to do so, subject to the following conditions:
- #
+ #
  # The above copyright notice and this permission notice shall be included in all
  # copies or substantial portions of the Software.
- #
+ #
  # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
@@ -49,6 +49,9 @@ class gbart:
  The training responses.
  x_test : array (p, m) or DataFrame, optional
  The test predictors.
+ usequants : bool, default False
+ Whether to use predictor quantiles instead of a uniform grid to bin
+ predictors.
  sigest : float, optional
  An estimate of the residual standard deviation on `y_train`, used to
  set `lamda`. If not specified, it is estimated by linear regression.
@@ -82,10 +85,16 @@ class gbart:
  ntree : int, default 200
  The number of trees used to represent the latent mean function.
  numcut : int, default 255
- The maximum number of cutpoints to use for binning the predictors. Each
- predictor is binned such that its distribution in `x_train` is
- approximately uniform across bins. The number of bins is at most the
- number of unique values appearing in `x_train`, or ``numcut + 1``.
+ If `usequants` is `False`: the exact number of cutpoints used to bin the
+ predictors, ranging between the minimum and maximum observed values
+ (excluded).
+
+ If `usequants` is `True`: the maximum number of cutpoints to use for
+ binning the predictors. Each predictor is binned such that its
+ distribution in `x_train` is approximately uniform across bins. The
+ number of bins is at most the number of unique values appearing in
+ `x_train`, or ``numcut + 1``.
+
  Before running the algorithm, the predictors are compressed to the
  smallest integer type that fits the bin indices, so `numcut` is best set
  to the maximum value of an unsigned integer type.
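The two binning schemes this docstring distinguishes can be illustrated with a small NumPy sketch. The helper names here are hypothetical and the quantile placement is only a loose approximation of bartz's actual `prepcovars` routines:

```python
import numpy as np

def uniform_cutpoints(x, numcut):
    # `numcut` equally spaced cutpoints strictly between min and max
    lo, hi = x.min(), x.max()
    return lo + (hi - lo) * np.arange(1, numcut + 1) / (numcut + 1)

def quantile_cutpoints(x, numcut):
    # at most `numcut` cutpoints placed at midpoints between unique values,
    # so observations spread roughly uniformly across bins
    u = np.unique(x)
    mid = (u[1:] + u[:-1]) / 2
    if mid.size <= numcut:
        return mid
    # otherwise keep midpoints at evenly spaced ranks
    idx = np.linspace(0, mid.size - 1, numcut).round().astype(int)
    return mid[idx]

x = np.array([0., 1., 1., 1., 2., 10.])
print(uniform_cutpoints(x, 4))   # evenly spaced in (0, 10)
print(quantile_cutpoints(x, 4))  # adapts to where the data lie
```

With skewed data like the example above, the uniform grid wastes cutpoints in the empty region between 2 and 10, while the quantile grid concentrates them where the observations are.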
@@ -126,6 +135,8 @@ class gbart:
  The number of trees.
  maxdepth : int
  The maximum depth of the trees.
+ initkw : dict
+ Additional arguments passed to `mcmcstep.init`.
 
  Methods
  -------
@@ -133,21 +144,26 @@ class gbart:
 
  Notes
  -----
- This interface imitates the function `gbart` from the R package `BART
+ This interface imitates the function ``gbart`` from the R package `BART
  <https://cran.r-project.org/package=BART>`_, but with these differences:
 
  - If `x_train` and `x_test` are matrices, they have one predictor per row
  instead of per column.
+ - If ``usequants=False``, R BART switches to quantiles anyway if there are
+ fewer predictor values than the required number of bins, while bartz
+ always follows the specification.
  - The error variance parameter is called `lamda` instead of `lambda`.
- - `usequants` is always `True`.
  - `rm_const` is always `False`.
  - The default `numcut` is 255 instead of 100.
  - A lot of functionality is missing (variable selection, discrete response).
  - There are some additional attributes, and some missing.
+
+ The linear regression used to set `sigest` adds an intercept.
  """
 
  def __init__(self, x_train, y_train, *,
  x_test=None,
+ usequants=False,
  sigest=None,
  sigdf=3,
  sigquant=0.9,
@@ -164,24 +180,25 @@ class gbart:
  keepevery=1,
  printevery=100,
  seed=0,
+ initkw={},
  ):
 
  x_train, x_train_fmt = self._process_predictor_input(x_train)
-
+
  y_train, y_train_fmt = self._process_response_input(y_train)
  self._check_same_length(x_train, y_train)
-
+
  offset = self._process_offset_settings(y_train, offset)
  scale = self._process_scale_settings(y_train, k)
  lamda, sigest = self._process_noise_variance_settings(x_train, y_train, sigest, sigdf, sigquant, lamda, offset)
 
- splits, max_split = self._determine_splits(x_train, numcut)
+ splits, max_split = self._determine_splits(x_train, usequants, numcut)
  x_train = self._bin_predictors(x_train, splits)
 
  y_train = self._transform_input(y_train, offset, scale)
  lamda_scaled = lamda / (scale * scale)
 
- mcmc_state = self._setup_mcmc(x_train, y_train, max_split, lamda_scaled, sigdf, power, base, maxdepth, ntree)
+ mcmc_state = self._setup_mcmc(x_train, y_train, max_split, lamda_scaled, sigdf, power, base, maxdepth, ntree, initkw)
  final_state, burnin_trace, main_trace = self._run_mcmc(mcmc_state, ndpost, nskip, keepevery, printevery, seed)
 
  sigma = self._extract_sigma(main_trace, scale)
@@ -279,7 +296,10 @@ class gbart:
  elif y_train.size <= x_train.shape[0]:
  sigest2 = jnp.var(y_train - offset)
  else:
- _, chisq, rank, _ = jnp.linalg.lstsq(x_train.T, y_train - offset)
+ x_centered = x_train.T - x_train.mean(axis=1)
+ y_centered = y_train - y_train.mean()
+ # centering is equivalent to adding an intercept column
+ _, chisq, rank, _ = jnp.linalg.lstsq(x_centered, y_centered)
  chisq = chisq.squeeze(0)
  dof = len(y_train) - rank
  sigest2 = chisq / dof
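The new comment's claim, that centering the predictors and the response is equivalent to adding an intercept column, is a standard least-squares fact and is easy to check numerically with plain NumPy (independent of bartz's code):

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(50, 3))
y = rng.normal(size=50) + 2.0  # nonzero mean, so the intercept matters

# fit with an explicit intercept column
X1 = np.column_stack([np.ones(len(X)), X])
_, res_intercept, _, _ = np.linalg.lstsq(X1, y, rcond=None)

# fit after centering predictors and response
Xc = X - X.mean(axis=0)
yc = y - y.mean()
_, res_centered, _, _ = np.linalg.lstsq(Xc, yc, rcond=None)

# the residual sums of squares agree
print(res_intercept, res_centered)
```

Both fits project onto the same column space (the constant column is absorbed by the centering), so the residual sum of squares, and hence the `sigest` estimate, is identical.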
@@ -305,8 +325,11 @@ class gbart:
  return (y_train.max() - y_train.min()) / (2 * k)
 
  @staticmethod
- def _determine_splits(x_train, numcut):
- return prepcovars.quantilized_splits_from_matrix(x_train, numcut + 1)
+ def _determine_splits(x_train, usequants, numcut):
+ if usequants:
+ return prepcovars.quantilized_splits_from_matrix(x_train, numcut + 1)
+ else:
+ return prepcovars.uniform_splits_from_matrix(x_train, numcut + 1)
 
  @staticmethod
  def _bin_predictors(x, splits):
@@ -317,12 +340,12 @@ class gbart:
  return (y - offset) / scale
 
  @staticmethod
- def _setup_mcmc(x_train, y_train, max_split, lamda, sigdf, power, base, maxdepth, ntree):
+ def _setup_mcmc(x_train, y_train, max_split, lamda, sigdf, power, base, maxdepth, ntree, initkw):
  depth = jnp.arange(maxdepth - 1)
  p_nonterminal = base / (1 + depth).astype(float) ** power
  sigma2_alpha = sigdf / 2
  sigma2_beta = lamda * sigma2_alpha
- return mcmcstep.init(
+ kw = dict(
  X=x_train,
  y=y_train,
  max_split=max_split,
@@ -332,6 +355,8 @@ class gbart:
  sigma2_beta=sigma2_beta,
  min_points_per_leaf=5,
  )
+ kw.update(initkw)
+ return mcmcstep.init(**kw)
 
  @staticmethod
  def _run_mcmc(mcmc_state, ndpost, nskip, keepevery, printevery, seed):
@@ -354,7 +379,7 @@ class gbart:
  def _extract_sigma(trace, scale):
  return scale * jnp.sqrt(trace['sigma2'])
 
-
+
  def _show_tree(self, i_sample, i_tree, print_all=False):
  from . import debug
  trace = self._main_trace
@@ -0,0 +1 @@
+ __version__ = '0.3.0'
@@ -10,10 +10,10 @@
  # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  # copies of the Software, and to permit persons to whom the Software is
  # furnished to do so, subject to the following conditions:
- #
+ #
  # The above copyright notice and this permission notice shall be included in all
  # copies or substantial portions of the Software.
- #
+ #
  # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
@@ -114,7 +114,7 @@ def traverse_tree(x, var_tree, split_tree):
 
  split = split_tree[index]
  var = var_tree[index]
-
+
  leaf_found |= split == 0
  child_index = (index << 1) + (x[var] >= split)
  index = jnp.where(leaf_found, index, child_index)
@@ -147,7 +147,7 @@ def traverse_forest(X, var_trees, split_trees):
  """
  return traverse_tree(X, var_trees, split_trees)
 
- def evaluate_forest(X, leaf_trees, var_trees, split_trees, dtype):
+ def evaluate_forest(X, leaf_trees, var_trees, split_trees, dtype=None, sum_trees=True):
  """
  Evaluate an ensemble of trees at an array of points.
 
@@ -162,21 +162,26 @@ def evaluate_forest(X, leaf_trees, var_trees, split_trees, dtype):
  The decision axes of the trees.
  split_trees : array (m, 2 ** (d - 1))
  The decision boundaries of the trees.
- dtype : dtype
- The dtype of the output.
+ dtype : dtype, optional
+ The dtype of the output. Ignored if `sum_trees` is `False`.
+ sum_trees : bool, default True
+ Whether to sum the values across trees.
 
  Returns
  -------
- out : array (n,)
- The sum of the values of the trees at the points in `X`.
+ out : array (n,) or (m, n)
+ The (sum of the) values of the trees at the points in `X`.
  """
  indices = traverse_forest(X, var_trees, split_trees)
  ntree, _ = leaf_trees.shape
- tree_index = jnp.arange(ntree, dtype=jaxext.minimal_unsigned_dtype(ntree - 1))[:, None]
- leaves = leaf_trees[tree_index, indices]
- return jnp.sum(leaves, axis=0, dtype=dtype)
- # this sum suggests to swap the vmaps, but I think it's better for X
- # copying to keep it that way
+ tree_index = jnp.arange(ntree, dtype=jaxext.minimal_unsigned_dtype(ntree - 1))
+ leaves = leaf_trees[tree_index[:, None], indices]
+ if sum_trees:
+ return jnp.sum(leaves, axis=0, dtype=dtype)
+ # this sum suggests to swap the vmaps, but I think it's better for X
+ # copying to keep it that way
+ else:
+ return leaves
 
  def is_actual_leaf(split_tree, *, add_bottom_level=False):
  """
@@ -238,7 +243,7 @@ def tree_depths(tree_length):
  tree_length : int
  The length of the tree array, i.e., 2 ** d.
 
- Returns
+ Returns
  -------
  depth : array (tree_length,)
  The depth of each node. The root node (index 1) has depth 0. The depth
@@ -196,13 +196,14 @@ def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
  A jittable function with positional arguments only, with inputs and
  outputs pytrees of arrays.
  max_io_nbytes : int
- The maximum number of input + output bytes in each batch.
- in_axes : pytree of ints, default 0
+ The maximum number of input + output bytes in each batch (excluding
+ unbatched arguments).
+ in_axes : pytree of int or None, default 0
  A tree matching the structure of the function input, indicating along
  which axes each array should be batched. If a single integer, it is
- used for all arrays.
+ used for all arrays. A `None` axis indicates to not batch an argument.
  out_axes : pytree of ints, default 0
- The same for outputs.
+ The same for outputs (but non-batching is not allowed).
  return_nbatches : bool, default False
  If True, the number of batches is returned as a second output.
 
@@ -218,8 +219,18 @@ def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
  return tree_util.tree_map(lambda _: axes, tree)
  return tree_util.tree_map(lambda _, axis: axis, tree, axes)
 
+ def check_no_nones(axes, tree):
+ def check_not_none(_, axis):
+ assert axis is not None
+ tree_util.tree_map(check_not_none, tree, axes)
+
  def extract_size(axes, tree):
- sizes = tree_util.tree_map(lambda x, axis: x.shape[axis], tree, axes)
+ def get_size(x, axis):
+ if axis is None:
+ return None
+ else:
+ return x.shape[axis]
+ sizes = tree_util.tree_map(get_size, tree, axes)
  sizes, _ = tree_util.tree_flatten(sizes)
  assert all(s == sizes[0] for s in sizes)
  return sizes[0]
@@ -243,23 +254,37 @@ def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
  return dividend
 
  def next_divisor(dividend, min_divisor):
+ if dividend == 0:
+ return min_divisor
  if min_divisor * min_divisor <= dividend:
  return next_divisor_small(dividend, min_divisor)
  return next_divisor_large(dividend, min_divisor)
 
+ def pull_nonbatched(axes, tree):
+ def pull_nonbatched(x, axis):
+ if axis is None:
+ return None
+ else:
+ return x
+ return tree_util.tree_map(pull_nonbatched, tree, axes), tree
+
+ def push_nonbatched(axes, tree, original_tree):
+ def push_nonbatched(original_x, x, axis):
+ if axis is None:
+ return original_x
+ else:
+ return x
+ return tree_util.tree_map(push_nonbatched, original_tree, tree, axes)
+
  def move_axes_out(axes, tree):
- def move_axis_out(axis, x):
- if axis != 0:
- return jnp.moveaxis(x, axis, 0)
- return x
- return tree_util.tree_map(move_axis_out, axes, tree)
+ def move_axis_out(x, axis):
+ return jnp.moveaxis(x, axis, 0)
+ return tree_util.tree_map(move_axis_out, tree, axes)
 
  def move_axes_in(axes, tree):
- def move_axis_in(axis, x):
- if axis != 0:
- return jnp.moveaxis(x, 0, axis)
- return x
- return tree_util.tree_map(move_axis_in, axes, tree)
+ def move_axis_in(x, axis):
+ return jnp.moveaxis(x, 0, axis)
+ return tree_util.tree_map(move_axis_in, tree, axes)
 
  def batch(tree, nbatches):
  def batch(x):
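For reference, `next_divisor` finds the smallest divisor of the batched axis size that is at least the requested number of batches, so every batch receives the same number of elements; the new `dividend == 0` guard covers zero-sized inputs. A naive pure-Python equivalent (the real code splits the search into small and large divisors for efficiency):

```python
def next_divisor(dividend, min_divisor):
    # smallest divisor of `dividend` that is >= `min_divisor`
    if dividend == 0:
        # a zero-sized axis is divided evenly by any batch count
        return min_divisor
    for d in range(min_divisor, dividend + 1):
        if dividend % d == 0:
            return d
    return dividend

print(next_divisor(12, 5))  # -> 6, since 5 does not divide 12
print(next_divisor(0, 3))   # -> 3, the zero-size guard
```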
@@ -287,16 +312,17 @@ def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
 
  in_axes = expand_axes(initial_in_axes, args)
  out_axes = expand_axes(initial_out_axes, example_result)
+ check_no_nones(out_axes, example_result)
+
+ size = extract_size((in_axes, out_axes), (args, example_result))
 
- in_size = extract_size(in_axes, args)
- out_size = extract_size(out_axes, example_result)
- assert in_size == out_size
- size = in_size
+ args, nonbatched_args = pull_nonbatched(in_axes, args)
 
- total_nbytes = sum_nbytes(args) + sum_nbytes(example_result)
+ total_nbytes = sum_nbytes((args, example_result))
  min_nbatches = total_nbytes // max_io_nbytes + bool(total_nbytes % max_io_nbytes)
+ min_nbatches = max(1, min_nbatches)
  nbatches = next_divisor(size, min_nbatches)
- assert 1 <= nbatches <= size
+ assert 1 <= nbatches <= max(1, size)
  assert size % nbatches == 0
  assert total_nbytes % nbatches == 0
 
@@ -307,6 +333,7 @@ def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
 
  def loop(_, args):
  args = move_axes_in(in_axes, args)
+ args = push_nonbatched(in_axes, args, nonbatched_args)
  result = func(*args)
  result = move_axes_out(out_axes, result)
  return None, result
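The effect of the new `None` axes — pulling non-batched arguments out of the loop and passing them whole to every batch — can be sketched in plain NumPy. This is a simplified model that only batches along axis 0, unlike the real `autobatch`, which handles arbitrary axes and pytrees:

```python
import numpy as np

def autobatch_sketch(func, nbatches, in_axes):
    # arguments whose in_axes entry is None are passed whole to every
    # batch; the others are split into `nbatches` equal chunks on axis 0
    def batched(*args):
        size = next(a.shape[0] for a, ax in zip(args, in_axes) if ax == 0)
        step = size // nbatches
        out = []
        for i in range(nbatches):
            chunk = [a if ax is None else a[i * step:(i + 1) * step]
                     for a, ax in zip(args, in_axes)]
            out.append(func(*chunk))
        return np.concatenate(out, axis=0)
    return batched

f = lambda x, w: x @ w                   # w is shared, x is batched
g = autobatch_sketch(f, nbatches=2, in_axes=(0, None))
x = np.arange(8.).reshape(4, 2)
w = np.array([1., 10.])
print(np.allclose(g(x, w), f(x, w)))     # batched result matches
```

Because `w` is never split, it does not count toward the per-batch byte budget, which is why the updated docstring excludes unbatched arguments from `max_io_nbytes`.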
@@ -10,10 +10,10 @@
  # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  # copies of the Software, and to permit persons to whom the Software is
  # furnished to do so, subject to the following conditions:
- #
+ #
  # The above copyright notice and this permission notice shall be included in all
  # copies or substantial portions of the Software.
- #
+ #
  # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
@@ -34,8 +34,9 @@ from jax import debug
  from jax import numpy as jnp
  from jax import lax
 
- from . import mcmcstep
+ from . import jaxext
  from . import grove
+ from . import mcmcstep
 
  @functools.partial(jax.jit, static_argnums=(1, 2, 3, 4))
  def run_mcmc(bart, n_burn, n_save, n_skip, callback, key):
@@ -91,7 +92,7 @@ def run_mcmc(bart, n_burn, n_save, n_skip, callback, key):
  the fields in `burnin_trace`.
  """
 
- tracelist_burnin = 'sigma2', 'grow_prop_count', 'grow_acc_count', 'prune_prop_count', 'prune_acc_count'
+ tracelist_burnin = 'sigma2', 'grow_prop_count', 'grow_acc_count', 'prune_prop_count', 'prune_acc_count', 'ratios'
 
  tracelist_main = tracelist_burnin + ('leaf_trees', 'var_trees', 'split_trees')
 
@@ -102,14 +103,11 @@ def run_mcmc(bart, n_burn, n_save, n_skip, callback, key):
  key, subkey = random.split(key)
  bart = mcmcstep.step(bart, subkey)
  callback(bart=bart, burnin=burnin, i_total=i_total, i_skip=i_skip, **callback_kw)
- output = {key: bart[key] for key in tracelist}
+ output = {key: bart[key] for key in tracelist if key in bart}
  return (bart, i_total + 1, i_skip + 1, key), output
 
  def empty_trace(bart, tracelist):
- return {
- key: jnp.empty((0,) + bart[key].shape, bart[key].dtype)
- for key in tracelist
- }
+ return jax.vmap(lambda x: x, in_axes=None, out_axes=0, axis_size=0)(bart)
 
  if n_burn > 0:
  carry = bart, 0, 0, key
@@ -124,7 +122,7 @@ def run_mcmc(bart, n_burn, n_save, n_skip, callback, key):
  main_loop = functools.partial(inner_loop, tracelist=[], burnin=False)
  inner_carry = bart, i_total, 0, key
  (bart, i_total, _, key), _ = lax.scan(main_loop, inner_carry, None, n_skip)
- output = {key: bart[key] for key in tracelist_main}
+ output = {key: bart[key] for key in tracelist_main if key in bart}
  return (bart, i_total, key), output
 
  if n_save > 0:
@@ -135,12 +133,9 @@ def run_mcmc(bart, n_burn, n_save, n_skip, callback, key):
 
  return bart, burnin_trace, main_trace
 
- # TODO I could add an argument callback_state to carry over state. This would allow e.g. accumulating counts. If I made the callback return the mcmc state, I could modify the mcmc from the callback.
-
  @functools.lru_cache
  # cache to make the callback function object unique, such that the jit
- # of run_mcmc recognizes it => with the callback state, I can make
- # printevery a runtime quantity
+ # of run_mcmc recognizes it
  def make_simple_print_callback(printevery):
  """
  Create a logging callback function for MCMC iterations.
@@ -193,7 +188,10 @@ def evaluate_trace(trace, X):
  y : array (n_trace, n)
  The predictions for each iteration of the MCMC.
  """
+ evaluate_trees = functools.partial(grove.evaluate_forest, sum_trees=False)
+ evaluate_trees = jaxext.autobatch(evaluate_trees, 2 ** 29, (None, 0, 0, 0))
  def loop(_, state):
- return None, grove.evaluate_forest(X, state['leaf_trees'], state['var_trees'], state['split_trees'], jnp.float32)
+ values = evaluate_trees(X, state['leaf_trees'], state['var_trees'], state['split_trees'])
+ return None, jnp.sum(values, axis=0, dtype=jnp.float32)
  _, y = lax.scan(loop, None, trace)
  return y