bartz 0.1.0__tar.gz → 0.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: bartz
3
- Version: 0.1.0
3
+ Version: 0.2.0
4
4
  Summary: A JAX implementation of BART
5
5
  Home-page: https://github.com/Gattocrucco/bartz
6
6
  License: MIT
@@ -20,7 +20,13 @@ Project-URL: Bug Tracker, https://github.com/Gattocrucco/bartz/issues
20
20
  Project-URL: Repository, https://github.com/Gattocrucco/bartz
21
21
  Description-Content-Type: text/markdown
22
22
 
23
+ [![PyPI](https://img.shields.io/pypi/v/bartz)](https://pypi.org/project/bartz/)
24
+
23
25
  # BART vectoriZed
24
26
 
25
27
  A branchless vectorized implementation of Bayesian Additive Regression Trees (BART) in JAX.
26
28
 
29
+ BART is a nonparametric Bayesian regression technique. Given predictors $X$ and responses $y$, BART finds a function to predict $y$ given $X$. The result of the inference is a sample of possible functions, representing the uncertainty over the determination of the function.
30
+
31
+ This Python module provides an implementation of BART that runs on GPU, to process large datasets faster. It also performs well on CPU. Most other implementations of BART are for R, and run on CPU only.
32
+
bartz-0.2.0/README.md ADDED
@@ -0,0 +1,9 @@
1
+ [![PyPI](https://img.shields.io/pypi/v/bartz)](https://pypi.org/project/bartz/)
2
+
3
+ # BART vectoriZed
4
+
5
+ A branchless vectorized implementation of Bayesian Additive Regression Trees (BART) in JAX.
6
+
7
+ BART is a nonparametric Bayesian regression technique. Given predictors $X$ and responses $y$, BART finds a function to predict $y$ given $X$. The result of the inference is a sample of possible functions, representing the uncertainty over the determination of the function.
8
+
9
+ This Python module provides an implementation of BART that runs on GPU, to process large datasets faster. It also performs well on CPU. Most other implementations of BART are for R, and run on CPU only.
@@ -28,7 +28,7 @@ build-backend = "poetry.core.masonry.api"
28
28
 
29
29
  [tool.poetry]
30
30
  name = "bartz"
31
- version = "0.1.0"
31
+ version = "0.2.0"
32
32
  description = "A JAX implementation of BART"
33
33
  authors = ["Giacomo Petrillo <info@giacomopetrillo.com>"]
34
34
  license = "MIT"
@@ -53,6 +53,7 @@ ipython = "^8.22.2"
53
53
  matplotlib = "^3.8.3"
54
54
  appnope = "^0.1.4"
55
55
  tomli = "^2.0.1"
56
+ packaging = "^24.0"
56
57
 
57
58
  [tool.poetry.group.test.dependencies]
58
59
  coverage = "^7.4.3"
@@ -60,7 +61,7 @@ pytest = "^8.1.1"
60
61
 
61
62
  [tool.poetry.group.docs.dependencies]
62
63
  Sphinx = "^7.2.6"
63
- numpydoc = "^1.6.0"
64
+ numpydoc = "^1.6.0,<1.7.0" # 1.7.0 breaks linkcode, it seems
64
65
  myst-parser = "^2.0.0"
65
66
 
66
67
  [tool.pytest.ini_options]
@@ -1,4 +1,4 @@
1
- # bartz/src/bartz/interface.py
1
+ # bartz/src/bartz/BART.py
2
2
  #
3
3
  # Copyright (c) 2024, Giacomo Petrillo
4
4
  #
@@ -33,7 +33,7 @@ from . import mcmcstep
33
33
  from . import mcmcloop
34
34
  from . import prepcovars
35
35
 
36
- class BART:
36
+ class gbart:
37
37
  """
38
38
  Nonparametric regression with Bayesian Additive Regression Trees (BART).
39
39
 
@@ -133,7 +133,7 @@ class BART:
133
133
 
134
134
  Notes
135
135
  -----
136
- This interface imitates the function `wbart` from the R package `BART
136
+ This interface imitates the function `gbart` from the R package `BART
137
137
  <https://cran.r-project.org/package=BART>`_, but with these differences:
138
138
 
139
139
  - If `x_train` and `x_test` are matrices, they have one predictor per row
@@ -142,6 +142,7 @@ class BART:
142
142
  - `usequants` is always `True`.
143
143
  - `rm_const` is always `False`.
144
144
  - The default `numcut` is 255 instead of 100.
145
+ - A lot of functionality is missing (variable selection, discrete response).
145
146
  - There are some additional attributes, and some missing.
146
147
  """
147
148
 
@@ -30,6 +30,11 @@ See the manual at https://gattocrucco.github.io/bartz/docs
30
30
 
31
31
  from ._version import __version__
32
32
 
33
- from .interface import BART
33
+ from . import BART
34
34
 
35
35
  from . import debug
36
+ from . import grove
37
+ from . import mcmcstep
38
+ from . import mcmcloop
39
+ from . import prepcovars
40
+ from . import jaxext
@@ -0,0 +1 @@
1
+ __version__ = '0.2.0'
@@ -6,6 +6,7 @@ from jax import lax
6
6
 
7
7
  from . import grove
8
8
  from . import mcmcstep
9
+ from . import jaxext
9
10
 
10
11
  def print_tree(leaf_tree, var_tree, split_tree, print_all=False):
11
12
 
@@ -83,7 +84,7 @@ def trace_depth_distr(split_trees_trace):
83
84
  def points_per_leaf_distr(var_tree, split_tree, X):
84
85
  traverse_tree = jax.vmap(grove.traverse_tree, in_axes=(1, None, None))
85
86
  indices = traverse_tree(X, var_tree, split_tree)
86
- count_tree = jnp.zeros(2 * split_tree.size, dtype=grove.minimal_unsigned_dtype(indices.size))
87
+ count_tree = jnp.zeros(2 * split_tree.size, dtype=jaxext.minimal_unsigned_dtype(indices.size))
87
88
  count_tree = count_tree.at[indices].add(1)
88
89
  is_leaf = grove.is_actual_leaf(split_tree, add_bottom_level=True).view(jnp.uint8)
89
90
  return jnp.bincount(count_tree, is_leaf, length=X.shape[1] + 1)
@@ -103,7 +104,7 @@ def trace_points_per_leaf_distr(bart, X):
103
104
  return distr
104
105
 
105
106
  def check_types(leaf_tree, var_tree, split_tree, max_split):
106
- expected_var_dtype = grove.minimal_unsigned_dtype(max_split.size - 1)
107
+ expected_var_dtype = jaxext.minimal_unsigned_dtype(max_split.size - 1)
107
108
  expected_split_dtype = max_split.dtype
108
109
  return var_tree.dtype == expected_var_dtype and split_tree.dtype == expected_split_dtype
109
110
 
@@ -117,7 +118,7 @@ def check_leaf_values(leaf_tree, var_tree, split_tree, max_split):
117
118
  return jnp.all(jnp.isfinite(leaf_tree))
118
119
 
119
120
  def check_stray_nodes(leaf_tree, var_tree, split_tree, max_split):
120
- index = jnp.arange(2 * split_tree.size, dtype=grove.minimal_unsigned_dtype(2 * split_tree.size - 1))
121
+ index = jnp.arange(2 * split_tree.size, dtype=jaxext.minimal_unsigned_dtype(2 * split_tree.size - 1))
121
122
  parent_index = index >> 1
122
123
  is_not_leaf = split_tree.at[index].get(mode='fill', fill_value=0) != 0
123
124
  parent_is_leaf = split_tree[parent_index] == 0
@@ -134,7 +135,7 @@ check_functions = [
134
135
  ]
135
136
 
136
137
  def check_tree(leaf_tree, var_tree, split_tree, max_split):
137
- error_type = grove.minimal_unsigned_dtype(2 ** len(check_functions) - 1)
138
+ error_type = jaxext.minimal_unsigned_dtype(2 ** len(check_functions) - 1)
138
139
  error = error_type(0)
139
140
  for i, func in enumerate(check_functions):
140
141
  ok = func(leaf_tree, var_tree, split_tree, max_split)
@@ -44,7 +44,6 @@ import functools
44
44
  import math
45
45
 
46
46
  import jax
47
-
48
47
  from jax import numpy as jnp
49
48
  from jax import lax
50
49
 
@@ -107,29 +106,47 @@ def traverse_tree(x, var_tree, split_tree):
107
106
 
108
107
  carry = (
109
108
  jnp.zeros((), bool),
110
- jnp.ones((), minimal_unsigned_dtype(2 * var_tree.size - 1)),
109
+ jnp.ones((), jaxext.minimal_unsigned_dtype(2 * var_tree.size - 1)),
111
110
  )
112
111
 
113
112
  def loop(carry, _):
114
113
  leaf_found, index = carry
115
114
 
116
- split = split_tree.at[index].get(mode='fill', fill_value=0)
117
- var = var_tree.at[index].get(mode='fill', fill_value=0)
115
+ split = split_tree[index]
116
+ var = var_tree[index]
118
117
 
119
- leaf_found |= split_tree.at[index].get(mode='fill', fill_value=0) == 0
118
+ leaf_found |= split == 0
120
119
  child_index = (index << 1) + (x[var] >= split)
121
120
  index = jnp.where(leaf_found, index, child_index)
122
121
 
123
122
  return (leaf_found, index), None
124
123
 
125
- # TODO
126
- # - unroll (how much? 5?)
127
- # - separate and special-case the last iteration
128
-
129
- depth = 1 + tree_depth(var_tree)
130
- (_, index), _ = lax.scan(loop, carry, None, depth)
124
+ depth = tree_depth(var_tree)
125
+ (_, index), _ = lax.scan(loop, carry, None, depth, unroll=16)
131
126
  return index
132
127
 
128
@functools.partial(jaxext.vmap_nodoc, in_axes=(None, 0, 0))
@functools.partial(jaxext.vmap_nodoc, in_axes=(1, None, None))
def traverse_forest(X, var_trees, split_trees):
    """
    Find the leaves where points fall into.

    Parameters
    ----------
    X : array (p, n)
        The coordinates to evaluate the trees at.
    var_trees : array (m, 2 ** (d - 1))
        The decision axes of the trees.
    split_trees : array (m, 2 ** (d - 1))
        The decision boundaries of the trees.

    Returns
    -------
    indices : array (m, n)
        The indices of the leaves.
    """
    # The single-point, single-tree traversal is vectorized twice: the inner
    # vmap maps over points (axis 1 of X), the outer one over trees (axis 0 of
    # the tree arrays). `vmap_nodoc` is used instead of `jax.vmap` so this
    # docstring survives the wrapping.
    return traverse_tree(X, var_trees, split_trees)
149
+
133
150
  def evaluate_forest(X, leaf_trees, var_trees, split_trees, dtype):
134
151
  """
135
152
  Evaluate a ensemble of trees at an array of points.
@@ -138,7 +155,7 @@ def evaluate_forest(X, leaf_trees, var_trees, split_trees, dtype):
138
155
  ----------
139
156
  X : array (p, n)
140
157
  The coordinates to evaluate the trees at.
141
- leaf_trees : (m, 2 ** d)
158
+ leaf_trees : array (m, 2 ** d)
142
159
  The leaf values of the tree or forest. If the input is a forest, the
143
160
  first axis is the tree index, and the values are summed.
144
161
  var_trees : array (m, 2 ** (d - 1))
@@ -153,30 +170,13 @@ def evaluate_forest(X, leaf_trees, var_trees, split_trees, dtype):
153
170
  out : array (n,)
154
171
  The sum of the values of the trees at the points in `X`.
155
172
  """
156
- indices = _traverse_forest(X, var_trees, split_trees)
173
+ indices = traverse_forest(X, var_trees, split_trees)
157
174
  ntree, _ = leaf_trees.shape
158
- tree_index = jnp.arange(ntree, dtype=minimal_unsigned_dtype(ntree - 1))[:, None]
175
+ tree_index = jnp.arange(ntree, dtype=jaxext.minimal_unsigned_dtype(ntree - 1))[:, None]
159
176
  leaves = leaf_trees[tree_index, indices]
160
177
  return jnp.sum(leaves, axis=0, dtype=dtype)
161
- # this sum suggests to swap the vmaps, but I think it's better for X copying to keep it that way
162
-
163
- @functools.partial(jax.vmap, in_axes=(None, 0, 0))
164
- @functools.partial(jax.vmap, in_axes=(1, None, None))
165
- def _traverse_forest(X, var_trees, split_trees):
166
- return traverse_tree(X, var_trees, split_trees)
167
-
168
- def minimal_unsigned_dtype(max_value):
169
- """
170
- Return the smallest unsigned integer dtype that can represent a given
171
- maximum value.
172
- """
173
- if max_value < 2 ** 8:
174
- return jnp.uint8
175
- if max_value < 2 ** 16:
176
- return jnp.uint16
177
- if max_value < 2 ** 32:
178
- return jnp.uint32
179
- return jnp.uint64
178
+ # this sum suggests to swap the vmaps, but I think it's better for X
179
+ # copying to keep it that way
180
180
 
181
181
  def is_actual_leaf(split_tree, *, add_bottom_level=False):
182
182
  """
@@ -200,7 +200,7 @@ def is_actual_leaf(split_tree, *, add_bottom_level=False):
200
200
  if add_bottom_level:
201
201
  size *= 2
202
202
  is_leaf = jnp.concatenate([is_leaf, jnp.ones_like(is_leaf)])
203
- index = jnp.arange(size, dtype=minimal_unsigned_dtype(size - 1))
203
+ index = jnp.arange(size, dtype=jaxext.minimal_unsigned_dtype(size - 1))
204
204
  parent_index = index >> 1
205
205
  parent_nonleaf = split_tree[parent_index].astype(bool)
206
206
  parent_nonleaf = parent_nonleaf.at[1].set(True)
@@ -220,7 +220,7 @@ def is_leaves_parent(split_tree):
220
220
  is_leaves_parent : bool array (2 ** (d - 1),)
221
221
  The mask indicating which nodes have leaf children.
222
222
  """
223
- index = jnp.arange(split_tree.size, dtype=minimal_unsigned_dtype(2 * split_tree.size - 1))
223
+ index = jnp.arange(split_tree.size, dtype=jaxext.minimal_unsigned_dtype(2 * split_tree.size - 1))
224
224
  left_index = index << 1 # left child
225
225
  right_index = left_index + 1 # right child
226
226
  left_leaf = split_tree.at[left_index].get(mode='fill', fill_value=0) == 0
@@ -252,4 +252,4 @@ def tree_depths(tree_length):
252
252
  depth += 1
253
253
  depths.append(depth - 1)
254
254
  depths[0] = 0
255
- return jnp.array(depths, minimal_unsigned_dtype(max(depths)))
255
+ return jnp.array(depths, jaxext.minimal_unsigned_dtype(max(depths)))
@@ -0,0 +1,341 @@
1
+ # bartz/src/bartz/jaxext.py
2
+ #
3
+ # Copyright (c) 2024, Giacomo Petrillo
4
+ #
5
+ # This file is part of bartz.
6
+ #
7
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
8
+ # of this software and associated documentation files (the "Software"), to deal
9
+ # in the Software without restriction, including without limitation the rights
10
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11
+ # copies of the Software, and to permit persons to whom the Software is
12
+ # furnished to do so, subject to the following conditions:
13
+ #
14
+ # The above copyright notice and this permission notice shall be included in all
15
+ # copies or substantial portions of the Software.
16
+ #
17
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23
+ # SOFTWARE.
24
+
25
+ import functools
26
+ import math
27
+ import warnings
28
+
29
+ from scipy import special
30
+ import jax
31
+ from jax import numpy as jnp
32
+ from jax import tree_util
33
+ from jax import lax
34
+
35
def float_type(*args):
    """
    Determine the jax floating point result type given operands/types.
    """
    common = jnp.result_type(*args)
    # apply a transcendental function to an empty array of the common type:
    # its result dtype is the floating type jax would promote to
    probe = jnp.empty(0, common)
    return jnp.sin(probe).dtype
41
+
42
def castto(func, type):
    """Wrap `func` such that its array result is cast to dtype `type`."""
    @functools.wraps(func)
    def newfunc(*args, **kw):
        raw = func(*args, **kw)
        return raw.astype(type)
    return newfunc
47
+
48
def pure_callback_ufunc(callback, dtype, *args, excluded=None, **kwargs):
    """
    Version of `jax.pure_callback` that deals correctly with ufuncs,
    see `<https://github.com/google/jax/issues/17187>`_.

    Parameters
    ----------
    callback : callable
        A numpy-ufunc-like function, invoked on the host outside of tracing.
    dtype : dtype
        The dtype of the result.
    *args : arrays
        The arguments passed to `callback`, broadcast together.
    excluded : sequence of int, optional
        Positional indices of arguments that do not participate in the
        broadcasting and are passed through as-is.
    **kwargs :
        Additional keyword arguments forwarded to `jax.pure_callback`.
    """
    if excluded is None:
        excluded = ()
    # broadcast only the participating arguments to get the output shape
    shape = jnp.broadcast_shapes(*(
        a.shape
        for i, a in enumerate(args)
        if i not in excluded
    ))
    ndim = len(shape)
    # left-pad each participating argument with length-1 axes so they all
    # share the same rank before being handed to the callback machinery
    padded_args = [
        a if i in excluded
        else jnp.expand_dims(a, tuple(range(ndim - a.ndim)))
        for i, a in enumerate(args)
    ]
    result = jax.ShapeDtypeStruct(shape, dtype)
    # NOTE(review): `vectorized=True` is deprecated in recent jax releases in
    # favor of `vmap_method` — confirm against the jax version pinned here.
    return jax.pure_callback(callback, result, *padded_args, vectorized=True, **kwargs)
66
+
67
+ # TODO when jax solves this, check version and piggyback on original if new
68
+
69
class scipy:
    """
    Namespace imitating a small subset of the `scipy` package, with the
    functions wrapped so they are usable under jax tracing via host callbacks.
    """

    class special:

        @functools.wraps(special.gammainccinv)
        def gammainccinv(a, y):
            # evaluated on the host through a callback, since jax provides no
            # implementation of the inverse regularized upper incomplete gamma
            a = jnp.asarray(a)
            y = jnp.asarray(y)
            # promote to the floating dtype jax would use for these operands
            dtype = float_type(a.dtype, y.dtype)
            ufunc = castto(special.gammainccinv, dtype)
            return pure_callback_ufunc(ufunc, dtype, a, y)

    class stats:

        class invgamma:

            def ppf(q, a):
                # inverse-gamma quantile via the identity
                # P(X <= x) = gammaincc(a, 1/x) for X ~ InvGamma(a),
                # hence the q-quantile is 1 / gammainccinv(a, q)
                return 1 / scipy.special.gammainccinv(a, q)
87
+
88
@functools.wraps(jax.vmap)
def vmap_nodoc(fun, *args, **kw):
    """
    Version of `jax.vmap` that preserves the docstring of the input function.
    """
    # jax.vmap replaces the docstring with its own boilerplate; save the
    # original one and restore it on the vectorized function
    original_doc = fun.__doc__
    mapped = jax.vmap(fun, *args, **kw)
    mapped.__doc__ = original_doc
    return mapped
97
+
98
def huge_value(x):
    """
    Return the maximum value that can be stored in `x`.

    Parameters
    ----------
    x : array
        A numerical numpy or jax array.

    Returns
    -------
    maxval : scalar
        The maximum value allowed by `x`'s type (+inf for floats).
    """
    # floats have no finite "largest useful" value for this purpose
    if not jnp.issubdtype(x.dtype, jnp.integer):
        return jnp.inf
    return jnp.iinfo(x.dtype).max
116
+
117
def minimal_unsigned_dtype(max_value):
    """
    Return the smallest unsigned integer dtype that can represent a given
    maximum value (inclusive).
    """
    # try candidates in increasing width; fall through to 64 bits
    candidates = (
        (8, jnp.uint8),
        (16, jnp.uint16),
        (32, jnp.uint32),
    )
    for bits, dtype in candidates:
        if max_value < 2 ** bits:
            return dtype
    return jnp.uint64
129
+
130
def signed_to_unsigned(int_dtype):
    """
    Map a signed integer type to its unsigned counterpart. Unsigned types are
    passed through.
    """
    assert jnp.issubdtype(int_dtype, jnp.integer)
    if jnp.issubdtype(int_dtype, jnp.unsignedinteger):
        return int_dtype
    # compare with `==` (not a dict lookup) so that both dtype instances
    # (e.g. dtype('int8')) and scalar types (jnp.int8) match
    pairs = (
        (jnp.int8, jnp.uint8),
        (jnp.int16, jnp.uint16),
        (jnp.int32, jnp.uint32),
        (jnp.int64, jnp.uint64),
    )
    for signed, unsigned in pairs:
        if int_dtype == signed:
            return unsigned
146
+
147
def ensure_unsigned(x):
    """
    If x has signed integer type, cast it to the unsigned dtype of the same size.
    """
    target_dtype = signed_to_unsigned(x.dtype)
    return x.astype(target_dtype)
152
+
153
@functools.partial(jax.jit, static_argnums=(1,))
def unique(x, size, fill_value):
    """
    Restricted version of `jax.numpy.unique` that uses less memory.

    Parameters
    ----------
    x : 1d array
        The input array.
    size : int
        The length of the output.
    fill_value : scalar
        The value to fill the output with if `size` is greater than the number
        of unique values in `x`.

    Returns
    -------
    out : array (size,)
        The unique values in `x`, sorted, and right-padded with `fill_value`.
    actual_length : int
        The number of used values in `out`.
    """
    # degenerate cases: nothing to scan over
    if x.size == 0:
        return jnp.full(size, fill_value, x.dtype), 0
    if size == 0:
        return jnp.empty(0, x.dtype), 0

    sorted_x = jnp.sort(x)

    def write_next(carry, value):
        # advance the write cursor only when a new distinct value appears
        write_pos, read_pos, previous, buffer = carry
        write_pos = jnp.where(value == previous, write_pos, write_pos + 1)
        buffer = buffer.at[write_pos].set(value)
        return (write_pos, read_pos + 1, value, buffer), None

    # seeding `previous` with the first element means position 0 is always
    # written without advancing the cursor on the first step
    init = 0, 0, sorted_x[0], jnp.full(size, fill_value, x.dtype)
    (last_pos, _, _, out), _ = jax.lax.scan(write_next, init, sorted_x[:size])
    return out, last_pos + 1
188
+
189
def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False):
    """
    Batch a function such that each batch is smaller than a threshold.

    Parameters
    ----------
    func : callable
        A jittable function with positional arguments only, with inputs and
        outputs pytrees of arrays.
    max_io_nbytes : int
        The maximum number of input + output bytes in each batch.
    in_axes : pytree of ints, default 0
        A tree matching the structure of the function input, indicating along
        which axes each array should be batched. If a single integer, it is
        used for all arrays.
    out_axes : pytree of ints, default 0
        The same for outputs.
    return_nbatches : bool, default False
        If True, the number of batches is returned as a second output.

    Returns
    -------
    batched_func : callable
        A function with the same signature as `func`, but that processes the
        input and output in batches in a loop.
    """

    # broadcast a scalar axis spec to the whole tree, or align a tree of axes
    # with the tree of arrays
    def expand_axes(axes, tree):
        if isinstance(axes, int):
            return tree_util.tree_map(lambda _: axes, tree)
        return tree_util.tree_map(lambda _, axis: axis, tree, axes)

    # length of the batched axis, checked to be identical across all arrays
    def extract_size(axes, tree):
        sizes = tree_util.tree_map(lambda x, axis: x.shape[axis], tree, axes)
        sizes, _ = tree_util.tree_flatten(sizes)
        assert all(s == sizes[0] for s in sizes)
        return sizes[0]

    # total storage of all arrays in the tree, in bytes
    def sum_nbytes(tree):
        def nbytes(x):
            return math.prod(x.shape) * x.dtype.itemsize
        return tree_util.tree_reduce(lambda size, x: size + nbytes(x), tree, 0)

    # smallest divisor of `dividend` that is >= min_divisor, searching upward
    # (used when min_divisor <= sqrt(dividend))
    def next_divisor_small(dividend, min_divisor):
        for divisor in range(min_divisor, int(math.sqrt(dividend)) + 1):
            if dividend % divisor == 0:
                return divisor
        return dividend

    # same, but searching by cofactor when min_divisor > sqrt(dividend)
    def next_divisor_large(dividend, min_divisor):
        max_inv_divisor = dividend // min_divisor
        for inv_divisor in range(max_inv_divisor, 0, -1):
            if dividend % inv_divisor == 0:
                return dividend // inv_divisor
        return dividend

    def next_divisor(dividend, min_divisor):
        if min_divisor * min_divisor <= dividend:
            return next_divisor_small(dividend, min_divisor)
        return next_divisor_large(dividend, min_divisor)

    # bring each array's batched axis to position 0 (for scanning)
    def move_axes_out(axes, tree):
        def move_axis_out(axis, x):
            if axis != 0:
                return jnp.moveaxis(x, axis, 0)
            return x
        return tree_util.tree_map(move_axis_out, axes, tree)

    # inverse of move_axes_out
    def move_axes_in(axes, tree):
        def move_axis_in(axis, x):
            if axis != 0:
                return jnp.moveaxis(x, 0, axis)
            return x
        return tree_util.tree_map(move_axis_in, axes, tree)

    # reshape leading axis (size,) -> (nbatches, size // nbatches)
    def batch(tree, nbatches):
        def batch(x):
            return x.reshape((nbatches, x.shape[0] // nbatches) + x.shape[1:])
        return tree_util.tree_map(batch, tree)

    # inverse of batch: merge the two leading axes back together
    def unbatch(tree):
        def unbatch(x):
            return x.reshape((x.shape[0] * x.shape[1],) + x.shape[2:])
        return tree_util.tree_map(unbatch, tree)

    # sanity check: the batched computation produced the same shapes/dtypes
    # that eval_shape predicted for the unbatched call
    def check_same(tree1, tree2):
        def check_same(x1, x2):
            assert x1.shape == x2.shape
            assert x1.dtype == x2.dtype
        tree_util.tree_map(check_same, tree1, tree2)

    # keep the user-provided specs; the names in_axes/out_axes are rebound
    # inside batched_func after expansion
    initial_in_axes = in_axes
    initial_out_axes = out_axes

    @jax.jit
    @functools.wraps(func)
    def batched_func(*args):
        # shapes/dtypes of the result without running the computation
        example_result = jax.eval_shape(func, *args)

        in_axes = expand_axes(initial_in_axes, args)
        out_axes = expand_axes(initial_out_axes, example_result)

        in_size = extract_size(in_axes, args)
        out_size = extract_size(out_axes, example_result)
        assert in_size == out_size
        size = in_size

        # pick the smallest number of batches that both divides `size` evenly
        # and keeps each batch under the byte threshold (ceil division)
        total_nbytes = sum_nbytes(args) + sum_nbytes(example_result)
        min_nbatches = total_nbytes // max_io_nbytes + bool(total_nbytes % max_io_nbytes)
        nbatches = next_divisor(size, min_nbatches)
        assert 1 <= nbatches <= size
        assert size % nbatches == 0
        assert total_nbytes % nbatches == 0

        # even one item per batch may exceed the threshold; warn, don't fail
        batch_nbytes = total_nbytes // nbatches
        if batch_nbytes > max_io_nbytes:
            assert size == nbatches
            warnings.warn(f'batch_nbytes = {batch_nbytes} > max_io_nbytes = {max_io_nbytes}')

        # each scan step restores the user's axis layout, applies `func`, and
        # moves the batched axis back to the front for stacking
        def loop(_, args):
            args = move_axes_in(in_axes, args)
            result = func(*args)
            result = move_axes_out(out_axes, result)
            return None, result

        args = move_axes_out(in_axes, args)
        args = batch(args, nbatches)
        _, result = lax.scan(loop, None, args)
        result = unbatch(result)
        result = move_axes_in(out_axes, result)

        check_same(example_result, result)

        if return_nbatches:
            return result, nbatches
        return result

    return batched_func
327
+
328
@tree_util.register_pytree_node_class
class LeafDict(dict):
    """
    Dictionary that acts as a leaf in jax pytrees, to store compile-time
    values.
    """

    def tree_flatten(self):
        # no dynamic children: the whole dict rides along as static aux data
        return tuple(), self

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # the dict stored as aux data comes back unchanged
        return aux_data

    def __repr__(self):
        return '{}({})'.format(__class__.__name__, super().__repr__())