bartz-0.5.0-py3-none-any.whl → bartz-0.7.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bartz/prepcovars.py CHANGED
@@ -22,64 +22,113 @@
  # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  # SOFTWARE.
 
- import functools
+ """Functions to preprocess data."""
 
- import jax
+ from functools import partial
+
+ from jax import jit, vmap
  from jax import numpy as jnp
+ from jaxtyping import Array, Float, Integer, Real, UInt
+
+ from bartz.jaxext import autobatch, minimal_unsigned_dtype, unique
+
 
- from . import jaxext
+ def parse_xinfo(
+     xinfo: Float[Array, 'p m'],
+ ) -> tuple[Float[Array, 'p m'], UInt[Array, ' p']]:
+     """Parse pre-defined splits in the format of the R package BART.
+
+     Parameters
+     ----------
+     xinfo
+         A matrix with the cutpoints to use to bin each predictor. Each row shall
+         contain a sorted list of cutpoints for a predictor. If there are fewer
+         cutpoints than the number of columns in the matrix, fill the remaining
+         cells with NaN.
+
+         `xinfo` shall be a matrix even if `x_train` is a dataframe.
+
+     Returns
+     -------
+     splits : Float[Array, 'p m']
+         `xinfo` modified by replacing nan with a large value.
+     max_split : UInt[Array, 'p']
+         The number of non-nan elements in each row of `xinfo`.
+     """
+     is_not_nan = ~jnp.isnan(xinfo)
+     max_split = jnp.sum(is_not_nan, axis=1)
+     max_split = max_split.astype(minimal_unsigned_dtype(xinfo.shape[1]))
+     huge = _huge_value(xinfo)
+     splits = jnp.where(is_not_nan, xinfo, huge)
+     return splits, max_split
 
 
- @functools.partial(jax.jit, static_argnums=(1,))
- def quantilized_splits_from_matrix(X, max_bins):
+ @partial(jit, static_argnums=(1,))
+ def quantilized_splits_from_matrix(
+     X: Real[Array, 'p n'], max_bins: int
+ ) -> tuple[Real[Array, 'p m'], UInt[Array, ' p']]:
      """
      Determine bins that make the distribution of each predictor uniform.
 
      Parameters
      ----------
-     X : array (p, n)
+     X
          A matrix with `p` predictors and `n` observations.
-     max_bins : int
+     max_bins
          The maximum number of bins to produce.
 
      Returns
      -------
-     splits : array (p, m)
+     splits : Real[Array, 'p m']
          A matrix containing, for each predictor, the boundaries between bins.
          `m` is ``min(max_bins, n) - 1``, which is an upper bound on the number
          of splits. Each predictor may have a different number of splits; unused
          values at the end of each row are filled with the maximum value
          representable in the type of `X`.
-     max_split : array (p,)
+     max_split : UInt[Array, ' p']
          The number of actually used values in each row of `splits`.
+
+     Raises
+     ------
+     ValueError
+         If `X` has no columns or if `max_bins` is less than 1.
      """
      out_length = min(max_bins, X.shape[1]) - 1
 
-     # return _quantilized_splits_from_matrix(X, out_length)
-     @functools.partial(jaxext.autobatch, max_io_nbytes=2**29)
+     if out_length < 0:
+         msg = f'{X.shape[1]=} and {max_bins=}, they should be both at least 1.'
+         raise ValueError(msg)
+
+     @partial(autobatch, max_io_nbytes=2**29)
      def quantilize(X):
+         # wrap this function because autobatch needs traceable args
          return _quantilized_splits_from_matrix(X, out_length)
 
      return quantilize(X)
 
 
- @functools.partial(jax.vmap, in_axes=(0, None))
- def _quantilized_splits_from_matrix(x, out_length):
-     huge = jaxext.huge_value(x)
-     u, actual_length = jaxext.unique(x, size=x.size, fill_value=huge)
-     actual_length -= 1
+ @partial(vmap, in_axes=(0, None))
+ def _quantilized_splits_from_matrix(
+     x: Real[Array, 'p n'], out_length: int
+ ) -> tuple[Real[Array, 'p m'], UInt[Array, ' p']]:
+     # find the sorted unique values in x
+     huge = _huge_value(x)
+     u, actual_length = unique(x, size=x.size, fill_value=huge)
+
+     # compute the midpoints between each unique value
      if jnp.issubdtype(x.dtype, jnp.integer):
-         midpoints = u[:-1] + jaxext.ensure_unsigned(u[1:] - u[:-1]) // 2
-         indices = jnp.arange(
-             midpoints.size, dtype=jaxext.minimal_unsigned_dtype(midpoints.size - 1)
-         )
-         midpoints = jnp.where(indices < actual_length, midpoints, huge)
+         midpoints = u[:-1] + _ensure_unsigned(u[1:] - u[:-1]) // 2
      else:
-         midpoints = (u[1:] + u[:-1]) / 2
+         midpoints = u[:-1] + (u[1:] - u[:-1]) / 2
+         # using x_i + (x_i+1 - x_i) / 2 instead of (x_i + x_i+1) / 2 is to
+         # avoid overflow
+     actual_length -= 1
+     if midpoints.size:
+         midpoints = midpoints.at[actual_length].set(huge)
+
+     # take a subset of the midpoints if there are more than the requested maximum
      indices = jnp.linspace(-1, actual_length, out_length + 2)[1:-1]
-     indices = jnp.around(indices).astype(
-         jaxext.minimal_unsigned_dtype(midpoints.size - 1)
-     )
+     indices = jnp.around(indices).astype(minimal_unsigned_dtype(midpoints.size - 1))
      # indices calculation with float rather than int to avoid potential
      # overflow with int32, and to round to nearest instead of rounding down
      decimated_midpoints = midpoints[indices]
@@ -88,41 +137,92 @@ def _quantilized_splits_from_matrix(x, out_length):
          actual_length > out_length, decimated_midpoints, truncated_midpoints
      )
      max_split = jnp.minimum(actual_length, out_length)
-     max_split = max_split.astype(jaxext.minimal_unsigned_dtype(out_length))
+     max_split = max_split.astype(minimal_unsigned_dtype(out_length))
      return splits, max_split
 
 
- @functools.partial(jax.jit, static_argnums=(1,))
- def uniform_splits_from_matrix(X, num_bins):
+ def _huge_value(x: Array) -> int | float:
+     """
+     Return the maximum value that can be stored in `x`.
+
+     Parameters
+     ----------
+     x
+         A numerical numpy or jax array.
+
+     Returns
+     -------
+     The maximum value allowed by `x`'s type (finite for floats).
+     """
+     if jnp.issubdtype(x.dtype, jnp.integer):
+         return jnp.iinfo(x.dtype).max
+     else:
+         return float(jnp.finfo(x.dtype).max)
+
+
+ def _ensure_unsigned(x: Integer[Array, '*shape']) -> UInt[Array, '*shape']:
+     """If x has signed integer type, cast it to the unsigned dtype of the same size."""
+     return x.astype(_signed_to_unsigned(x.dtype))
+
+
+ def _signed_to_unsigned(int_dtype: jnp.dtype) -> jnp.dtype:
+     """
+     Map a signed integer type to its unsigned counterpart.
+
+     Unsigned types are passed through.
+     """
+     assert jnp.issubdtype(int_dtype, jnp.integer)
+     if jnp.issubdtype(int_dtype, jnp.unsignedinteger):
+         return int_dtype
+     match int_dtype:
+         case jnp.int8:
+             return jnp.uint8
+         case jnp.int16:
+             return jnp.uint16
+         case jnp.int32:
+             return jnp.uint32
+         case jnp.int64:
+             return jnp.uint64
+         case _:
+             msg = f'unexpected integer type {int_dtype}'
+             raise TypeError(msg)
+
+
+ @partial(jit, static_argnums=(1,))
+ def uniform_splits_from_matrix(
+     X: Real[Array, 'p n'], num_bins: int
+ ) -> tuple[Real[Array, 'p m'], UInt[Array, ' p']]:
      """
      Make an evenly spaced binning grid.
 
      Parameters
      ----------
-     X : array (p, n)
+     X
          A matrix with `p` predictors and `n` observations.
-     num_bins : int
+     num_bins
          The number of bins to produce.
 
      Returns
      -------
-     splits : array (p, num_bins - 1)
+     splits : Real[Array, 'p m']
          A matrix containing, for each predictor, the boundaries between bins.
          The excluded endpoints are the minimum and maximum value in each row of
          `X`.
-     max_split : array (p,)
+     max_split : UInt[Array, ' p']
          The number of cutpoints in each row of `splits`, i.e., ``num_bins - 1``.
      """
      low = jnp.min(X, axis=1)
      high = jnp.max(X, axis=1)
      splits = jnp.linspace(low, high, num_bins + 1, axis=1)[:, 1:-1]
      assert splits.shape == (X.shape[0], num_bins - 1)
-     max_split = jnp.full(*splits.shape, jaxext.minimal_unsigned_dtype(num_bins - 1))
+     max_split = jnp.full(*splits.shape, minimal_unsigned_dtype(num_bins - 1))
      return splits, max_split
 
 
- @functools.partial(jax.jit, static_argnames=('method',))
- def bin_predictors(X, splits, **kw):
+ @partial(jit, static_argnames=('method',))
+ def bin_predictors(
+     X: Real[Array, 'p n'], splits: Real[Array, 'p m'], **kw
+ ) -> UInt[Array, 'p n']:
      """
      Bin the predictors according to the given splits.
 
@@ -130,27 +230,25 @@ def bin_predictors(X, splits, **kw):
 
      Parameters
      ----------
-     X : array (p, n)
+     X
          A matrix with `p` predictors and `n` observations.
-     splits : array (p, m)
+     splits
          A matrix containing, for each predictor, the boundaries between bins.
          `m` is the maximum number of splits; each row may have shorter
          actual length, marked by padding unused locations at the end of the
          row with the maximum value allowed by the type.
-     **kw : dict
+     **kw
          Additional arguments are passed to `jax.numpy.searchsorted`.
 
      Returns
      -------
-     X_binned : int array (p, n)
-         A matrix with `p` predictors and `n` observations, where each predictor
-         has been replaced by the index of the bin it falls into.
+     `X` but with each value replaced by the index of the bin it falls into.
      """
 
-     @functools.partial(jaxext.autobatch, max_io_nbytes=2**29)
-     @jax.vmap
+     @partial(autobatch, max_io_nbytes=2**29)
+     @vmap
      def bin_predictors(x, splits):
-         dtype = jaxext.minimal_unsigned_dtype(splits.size)
+         dtype = minimal_unsigned_dtype(splits.size)
          return jnp.searchsorted(splits, x, **kw).astype(dtype)
 
      return bin_predictors(X, splits)
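
For orientation, here is a minimal usage sketch of the binning API shown in the diff above. It is illustrative only: the toy data is made up, and it assumes the 0.7.0 functions are importable as `bartz.prepcovars.quantilized_splits_from_matrix` and `bartz.prepcovars.bin_predictors` with the signatures shown.

    import jax.numpy as jnp

    from bartz import prepcovars

    # toy data: p=3 predictors by n=100 observations (predictors along rows)
    X = jnp.stack([
        jnp.linspace(0.0, 1.0, 100),
        jnp.arange(100.0),
        jnp.sin(jnp.arange(100.0)),
    ])

    # quantile-based grid with at most 32 bins per predictor
    splits, max_split = prepcovars.quantilized_splits_from_matrix(X, 32)

    # replace each value with the index of the bin it falls into
    X_binned = prepcovars.bin_predictors(X, splits)
    print(X_binned.shape)  # (3, 100), with a small unsigned integer dtype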
@@ -1,14 +1,15 @@
  Metadata-Version: 2.4
  Name: bartz
- Version: 0.5.0
+ Version: 0.7.0
  Summary: Super-fast BART (Bayesian Additive Regression Trees) in Python
  Author: Giacomo Petrillo
  Author-email: Giacomo Petrillo <info@giacomopetrillo.com>
  License-Expression: MIT
- Requires-Dist: jax>=0.4.35,<1
- Requires-Dist: jaxlib>=0.4.35,<1
- Requires-Dist: numpy>=1.25.2,<3
- Requires-Dist: scipy>=1.11.4,<2
+ Requires-Dist: equinox>=0.12.2
+ Requires-Dist: jax>=0.5.3
+ Requires-Dist: jaxtyping>=0.3.2
+ Requires-Dist: numpy>=1.25.2
+ Requires-Dist: scipy>=1.11.4
  Requires-Python: >=3.10
  Project-URL: Documentation, https://gattocrucco.github.io/bartz/docs-dev
  Project-URL: Homepage, https://github.com/Gattocrucco/bartz
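
The dependency changes above track the code changes: `equinox` is newly required, `jaxlib` is no longer pinned separately, and `jaxtyping` backs the shape-annotated signatures introduced in `prepcovars.py` (e.g. `Real[Array, 'p n']`). As a rough, hypothetical illustration of what such an annotation expresses (not code from the package):

    from jax import numpy as jnp
    from jaxtyping import Array, Real, UInt

    # 'p n' documents a 2D array whose axes are (predictors, observations);
    # Real accepts any real dtype, UInt only unsigned integers. Without a
    # runtime type checker, these annotations act as documentation.
    def count_observations(X: Real[Array, 'p n']) -> UInt[Array, '']:
        return jnp.uint32(X.shape[1])

    count_observations(jnp.zeros((3, 100)))  # scalar array holding 100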
@@ -0,0 +1,17 @@
+ bartz/.DS_Store,sha256=7191af46d7b8c0d4c03c502f94eb01353bc2e615d75c45b3af0e31ab238034b5,6148
+ bartz/BART.py,sha256=6b129e20a258d724e0cba5ffaf377b5c4d62e545c1f3a737602c2a0be5d84b96,29601
+ bartz/__init__.py,sha256=98c579136a8755390210ada33e713290749e4d7fca58550c791f77f192f4b4a1,1436
+ bartz/_version.py,sha256=d3d868979a2f2fa02b20b248259f8c8ac7273a329ae918da1139ad7602695b67,22
+ bartz/debug.py,sha256=5082c2dd07f6d3f8353c57491c857cca4aea913647e24502cc375f9aadc84975,43736
+ bartz/grove.py,sha256=f64505623feec7edcec96930909f9a9326f290976d7e9bdb7b9abcb32fe425fe,10559
+ bartz/jaxext/__init__.py,sha256=6cd2e7c23ccc4f0399fb3f7989312a500fd8bd3f7f07eb85ce537f2c8873f35a,6705
+ bartz/jaxext/_autobatch.py,sha256=b5dbaec52e39b4b32c824fde47eaaf33f496a1574cc336d3a79fa71f4c5e348a,7116
+ bartz/jaxext/scipy/__init__.py,sha256=a1f5990a75c1c73908565be4cd5fa1c07278ad3a02b78b21e2f1225b388ab6b5,1227
+ bartz/jaxext/scipy/special.py,sha256=f0e777c29a77d46d55ff2f0169b3d8581de24d8580b8d6f1c7c1473910fb62d9,8111
+ bartz/jaxext/scipy/stats.py,sha256=703beb9fcfe606a195fe9a3143eceed367c87a31bd75763ecdb6e509c2e87f53,1483
+ bartz/mcmcloop.py,sha256=e841c076bf5392b1ad8eab955c70a63d18ac321c249d255c7424b62e52c127c1,22134
+ bartz/mcmcstep.py,sha256=3b2a4bdad3bbf836efc7cc6f52adac3ba23dce00a1cca4a99e5a8d79d10c3e68,84092
+ bartz/prepcovars.py,sha256=50334621ca6ec7a6e35d21ce5ff8b96d4cbf0e696c9bcd10d86268e301a820a7,8728
+ bartz-0.7.0.dist-info/WHEEL,sha256=607c46fee47e440c91332c738096ff0f5e54ca3b0818ee85462dd5172a38e793,79
+ bartz-0.7.0.dist-info/METADATA,sha256=758e12296acf815c9bae50601ff502e6dee97c9ac0d817232e45e3dfb39665bc,2815
+ bartz-0.7.0.dist-info/RECORD,,
@@ -1,4 +1,4 @@
  Wheel-Version: 1.0
- Generator: uv 0.7.4
+ Generator: uv 0.7.19
  Root-Is-Purelib: true
  Tag: py3-none-any
bartz/jaxext.py DELETED
@@ -1,374 +0,0 @@
1
- # bartz/src/bartz/jaxext.py
2
- #
3
- # Copyright (c) 2024-2025, Giacomo Petrillo
4
- #
5
- # This file is part of bartz.
6
- #
7
- # Permission is hereby granted, free of charge, to any person obtaining a copy
8
- # of this software and associated documentation files (the "Software"), to deal
9
- # in the Software without restriction, including without limitation the rights
10
- # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11
- # copies of the Software, and to permit persons to whom the Software is
12
- # furnished to do so, subject to the following conditions:
13
- #
14
- # The above copyright notice and this permission notice shall be included in all
15
- # copies or substantial portions of the Software.
16
- #
17
- # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18
- # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19
- # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20
- # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21
- # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22
- # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23
- # SOFTWARE.
24
-
25
- import functools
26
- import math
27
- import warnings
28
-
29
- import jax
30
- from jax import lax, tree_util
31
- from jax import numpy as jnp
32
- from scipy import special
33
-
34
-
35
- def float_type(*args):
36
- """
37
- Determine the jax floating point result type given operands/types.
38
- """
39
- t = jnp.result_type(*args)
40
- return jnp.sin(jnp.empty(0, t)).dtype
41
-
42
-
43
- def castto(func, type):
44
- @functools.wraps(func)
45
- def newfunc(*args, **kw):
46
- return func(*args, **kw).astype(type)
47
-
48
- return newfunc
49
-
50
-
51
- class scipy:
52
- class special:
53
- @functools.wraps(special.gammainccinv)
54
- def gammainccinv(a, y):
55
- a = jnp.asarray(a)
56
- y = jnp.asarray(y)
57
- shape = jnp.broadcast_shapes(a.shape, y.shape)
58
- dtype = float_type(a.dtype, y.dtype)
59
- dummy = jax.ShapeDtypeStruct(shape, dtype)
60
- ufunc = castto(special.gammainccinv, dtype)
61
- return jax.pure_callback(ufunc, dummy, a, y, vmap_method='expand_dims')
62
-
63
- class stats:
64
- class invgamma:
65
- def ppf(q, a):
66
- return 1 / scipy.special.gammainccinv(a, q)
67
-
68
-
69
- def vmap_nodoc(fun, *args, **kw):
70
- """
71
- Wrapper of `jax.vmap` that preserves the docstring of the input function.
72
-
73
- This is useful if the docstring already takes into account that the
74
- arguments have additional axes due to vmap.
75
- """
76
- doc = fun.__doc__
77
- fun = jax.vmap(fun, *args, **kw)
78
- fun.__doc__ = doc
79
- return fun
80
-
81
-
82
- def huge_value(x):
83
- """
84
- Return the maximum value that can be stored in `x`.
85
-
86
- Parameters
87
- ----------
88
- x : array
89
- A numerical numpy or jax array.
90
-
91
- Returns
92
- -------
93
- maxval : scalar
94
- The maximum value allowed by `x`'s type (+inf for floats).
95
- """
96
- if jnp.issubdtype(x.dtype, jnp.integer):
97
- return jnp.iinfo(x.dtype).max
98
- else:
99
- return jnp.inf
100
-
101
-
102
- def minimal_unsigned_dtype(max_value):
103
- """
104
- Return the smallest unsigned integer dtype that can represent a given
105
- maximum value (inclusive).
106
- """
107
- if max_value < 2**8:
108
- return jnp.uint8
109
- if max_value < 2**16:
110
- return jnp.uint16
111
- if max_value < 2**32:
112
- return jnp.uint32
113
- return jnp.uint64
114
-
115
-
116
- def signed_to_unsigned(int_dtype):
117
- """
118
- Map a signed integer type to its unsigned counterpart. Unsigned types are
119
- passed through.
120
- """
121
- assert jnp.issubdtype(int_dtype, jnp.integer)
122
- if jnp.issubdtype(int_dtype, jnp.unsignedinteger):
123
- return int_dtype
124
- if int_dtype == jnp.int8:
125
- return jnp.uint8
126
- if int_dtype == jnp.int16:
127
- return jnp.uint16
128
- if int_dtype == jnp.int32:
129
- return jnp.uint32
130
- if int_dtype == jnp.int64:
131
- return jnp.uint64
132
-
133
-
134
- def ensure_unsigned(x):
135
- """
136
- If x has signed integer type, cast it to the unsigned dtype of the same size.
137
- """
138
- return x.astype(signed_to_unsigned(x.dtype))
139
-
140
-
141
- @functools.partial(jax.jit, static_argnums=(1,))
142
- def unique(x, size, fill_value):
143
- """
144
- Restricted version of `jax.numpy.unique` that uses less memory.
145
-
146
- Parameters
147
- ----------
148
- x : 1d array
149
- The input array.
150
- size : int
151
- The length of the output.
152
- fill_value : scalar
153
- The value to fill the output with if `size` is greater than the number
154
- of unique values in `x`.
155
-
156
- Returns
157
- -------
158
- out : array (size,)
159
- The unique values in `x`, sorted, and right-padded with `fill_value`.
160
- actual_length : int
161
- The number of used values in `out`.
162
- """
163
- if x.size == 0:
164
- return jnp.full(size, fill_value, x.dtype), 0
165
- if size == 0:
166
- return jnp.empty(0, x.dtype), 0
167
- x = jnp.sort(x)
168
-
169
- def loop(carry, x):
170
- i_out, i_in, last, out = carry
171
- i_out = jnp.where(x == last, i_out, i_out + 1)
172
- out = out.at[i_out].set(x)
173
- return (i_out, i_in + 1, x, out), None
174
-
175
- carry = 0, 0, x[0], jnp.full(size, fill_value, x.dtype)
176
- (actual_length, _, _, out), _ = jax.lax.scan(loop, carry, x[:size])
177
- return out, actual_length + 1
178
-
179
-
180
- def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False):
181
- """
182
- Batch a function such that each batch is smaller than a threshold.
183
-
184
- Parameters
185
- ----------
186
- func : callable
187
- A jittable function with positional arguments only, with inputs and
188
- outputs pytrees of arrays.
189
- max_io_nbytes : int
190
- The maximum number of input + output bytes in each batch (excluding
191
- unbatched arguments.)
192
- in_axes : pytree of int or None, default 0
193
- A tree matching the structure of the function input, indicating along
194
- which axes each array should be batched. If a single integer, it is
195
- used for all arrays. A `None` axis indicates to not batch an argument.
196
- out_axes : pytree of ints, default 0
197
- The same for outputs (but non-batching is not allowed).
198
- return_nbatches : bool, default False
199
- If True, the number of batches is returned as a second output.
200
-
201
- Returns
202
- -------
203
- batched_func : callable
204
- A function with the same signature as `func`, but that processes the
205
- input and output in batches in a loop.
206
- """
207
-
208
- def expand_axes(axes, tree):
209
- if isinstance(axes, int):
210
- return tree_util.tree_map(lambda _: axes, tree)
211
- return tree_util.tree_map(lambda _, axis: axis, tree, axes)
212
-
213
- def check_no_nones(axes, tree):
214
- def check_not_none(_, axis):
215
- assert axis is not None
216
-
217
- tree_util.tree_map(check_not_none, tree, axes)
218
-
219
- def extract_size(axes, tree):
220
- def get_size(x, axis):
221
- if axis is None:
222
- return None
223
- else:
224
- return x.shape[axis]
225
-
226
- sizes = tree_util.tree_map(get_size, tree, axes)
227
- sizes, _ = tree_util.tree_flatten(sizes)
228
- assert all(s == sizes[0] for s in sizes)
229
- return sizes[0]
230
-
231
- def sum_nbytes(tree):
232
- def nbytes(x):
233
- return math.prod(x.shape) * x.dtype.itemsize
234
-
235
- return tree_util.tree_reduce(lambda size, x: size + nbytes(x), tree, 0)
236
-
237
- def next_divisor_small(dividend, min_divisor):
238
- for divisor in range(min_divisor, int(math.sqrt(dividend)) + 1):
239
- if dividend % divisor == 0:
240
- return divisor
241
- return dividend
242
-
243
- def next_divisor_large(dividend, min_divisor):
244
- max_inv_divisor = dividend // min_divisor
245
- for inv_divisor in range(max_inv_divisor, 0, -1):
246
- if dividend % inv_divisor == 0:
247
- return dividend // inv_divisor
248
- return dividend
249
-
250
- def next_divisor(dividend, min_divisor):
251
- if dividend == 0:
252
- return min_divisor
253
- if min_divisor * min_divisor <= dividend:
254
- return next_divisor_small(dividend, min_divisor)
255
- return next_divisor_large(dividend, min_divisor)
256
-
257
- def pull_nonbatched(axes, tree):
258
- def pull_nonbatched(x, axis):
259
- if axis is None:
260
- return None
261
- else:
262
- return x
263
-
264
- return tree_util.tree_map(pull_nonbatched, tree, axes), tree
265
-
266
- def push_nonbatched(axes, tree, original_tree):
267
- def push_nonbatched(original_x, x, axis):
268
- if axis is None:
269
- return original_x
270
- else:
271
- return x
272
-
273
- return tree_util.tree_map(push_nonbatched, original_tree, tree, axes)
274
-
275
- def move_axes_out(axes, tree):
276
- def move_axis_out(x, axis):
277
- return jnp.moveaxis(x, axis, 0)
278
-
279
- return tree_util.tree_map(move_axis_out, tree, axes)
280
-
281
- def move_axes_in(axes, tree):
282
- def move_axis_in(x, axis):
283
- return jnp.moveaxis(x, 0, axis)
284
-
285
- return tree_util.tree_map(move_axis_in, tree, axes)
286
-
287
- def batch(tree, nbatches):
288
- def batch(x):
289
- return x.reshape((nbatches, x.shape[0] // nbatches) + x.shape[1:])
290
-
291
- return tree_util.tree_map(batch, tree)
292
-
293
- def unbatch(tree):
294
- def unbatch(x):
295
- return x.reshape((x.shape[0] * x.shape[1],) + x.shape[2:])
296
-
297
- return tree_util.tree_map(unbatch, tree)
298
-
299
- def check_same(tree1, tree2):
300
- def check_same(x1, x2):
301
- assert x1.shape == x2.shape
302
- assert x1.dtype == x2.dtype
303
-
304
- tree_util.tree_map(check_same, tree1, tree2)
305
-
306
- initial_in_axes = in_axes
307
- initial_out_axes = out_axes
308
-
309
- @jax.jit
310
- @functools.wraps(func)
311
- def batched_func(*args):
312
- example_result = jax.eval_shape(func, *args)
313
-
314
- in_axes = expand_axes(initial_in_axes, args)
315
- out_axes = expand_axes(initial_out_axes, example_result)
316
- check_no_nones(out_axes, example_result)
317
-
318
- size = extract_size((in_axes, out_axes), (args, example_result))
319
-
320
- args, nonbatched_args = pull_nonbatched(in_axes, args)
321
-
322
- total_nbytes = sum_nbytes((args, example_result))
323
- min_nbatches = total_nbytes // max_io_nbytes + bool(
324
- total_nbytes % max_io_nbytes
325
- )
326
- min_nbatches = max(1, min_nbatches)
327
- nbatches = next_divisor(size, min_nbatches)
328
- assert 1 <= nbatches <= max(1, size)
329
- assert size % nbatches == 0
330
- assert total_nbytes % nbatches == 0
331
-
332
- batch_nbytes = total_nbytes // nbatches
333
- if batch_nbytes > max_io_nbytes:
334
- assert size == nbatches
335
- warnings.warn(
336
- f'batch_nbytes = {batch_nbytes} > max_io_nbytes = {max_io_nbytes}'
337
- )
338
-
339
- def loop(_, args):
340
- args = move_axes_in(in_axes, args)
341
- args = push_nonbatched(in_axes, args, nonbatched_args)
342
- result = func(*args)
343
- result = move_axes_out(out_axes, result)
344
- return None, result
345
-
346
- args = move_axes_out(in_axes, args)
347
- args = batch(args, nbatches)
348
- _, result = lax.scan(loop, None, args)
349
- result = unbatch(result)
350
- result = move_axes_in(out_axes, result)
351
-
352
- check_same(example_result, result)
353
-
354
- if return_nbatches:
355
- return result, nbatches
356
- return result
357
-
358
- return batched_func
359
-
360
-
361
- @tree_util.register_pytree_node_class
362
- class LeafDict(dict):
363
- """dictionary that acts as a leaf in jax pytrees, to store compile-time
364
- values"""
365
-
366
- def tree_flatten(self):
367
- return (), self
368
-
369
- @classmethod
370
- def tree_unflatten(cls, aux_data, children):
371
- return aux_data
372
-
373
- def __repr__(self):
374
- return f'{__class__.__name__}({super().__repr__()})'