bartz 0.4.1__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bartz/debug.py CHANGED
@@ -1,21 +1,19 @@
1
1
  import functools
2
2
 
3
3
  import jax
4
- from jax import numpy as jnp
5
4
  from jax import lax
5
+ from jax import numpy as jnp
6
6
 
7
- from . import grove
8
- from . import mcmcstep
9
- from . import jaxext
7
+ from . import grove, jaxext
10
8
 
11
- def print_tree(leaf_tree, var_tree, split_tree, print_all=False):
12
9
 
10
+ def print_tree(leaf_tree, var_tree, split_tree, print_all=False):
13
11
  tee = '├──'
14
12
  corner = '└──'
15
13
  join = '│ '
16
14
  space = ' '
17
15
  down = '┐'
18
- bottom = '╢' # '┨' #
16
+ bottom = '╢' # '┨' #
19
17
 
20
18
  def traverse_tree(index, depth, indent, first_indent, next_indent, unused):
21
19
  if index >= len(leaf_tree):
@@ -58,7 +56,7 @@ def print_tree(leaf_tree, var_tree, split_tree, print_all=False):
58
56
 
59
57
  indent += next_indent
60
58
  unused = unused or is_leaf
61
-
59
+
62
60
  if unused and not print_all:
63
61
  return
64
62
 
@@ -67,58 +65,80 @@ def print_tree(leaf_tree, var_tree, split_tree, print_all=False):
67
65
 
68
66
  traverse_tree(1, 0, '', '', '', False)
69
67
 
68
+
70
69
def tree_actual_depth(split_tree):
    """
    Compute the depth of the deepest actual leaf of a tree.

    Parameters
    ----------
    split_tree : int array (2 ** (d - 1),)
        The decision boundaries of the tree.

    Returns
    -------
    depth : int
        The maximum depth (root = 0) among the nodes flagged by
        `grove.is_actual_leaf` with the bottom level included.
    """
    leaf_mask = grove.is_actual_leaf(split_tree, add_bottom_level=True)
    node_depths = grove.tree_depths(leaf_mask.size)
    # non-leaf positions contribute depth 0 so they never win the max
    return jnp.max(jnp.where(leaf_mask, node_depths, 0))
75
74
 
75
+
76
76
def forest_depth_distr(split_trees):
    """
    Histogram of the actual depths of the trees in a forest.

    Parameters
    ----------
    split_trees : int array (num_trees, 2 ** (d - 1))
        The decision boundaries of the trees.

    Returns
    -------
    distr : int array (d,)
        ``distr[k]`` is the number of trees whose actual depth is ``k``.
    """
    # arrays of length 2 ** (d - 1) give tree_depth == d - 1, so d bins
    nbins = grove.tree_depth(split_trees) + 1
    per_tree = jax.vmap(tree_actual_depth)(split_trees)
    return jnp.bincount(per_tree, length=nbins)
80
80
 
81
+
81
82
def trace_depth_distr(split_trees_trace):
    """
    Apply `forest_depth_distr` to every forest in a stack of forests.

    Parameters
    ----------
    split_trees_trace : int array (num_samples, num_trees, 2 ** (d - 1))
        The decision boundaries, with a leading axis that is vmapped over.

    Returns
    -------
    distr : int array (num_samples, d)
        One depth histogram per forest.
    """
    per_forest = jax.vmap(forest_depth_distr)
    return per_forest(split_trees_trace)
83
84
 
85
+
84
86
def points_per_leaf_distr(var_tree, split_tree, X):
    """
    Histogram of the number of points that fall in each leaf of a tree.

    Parameters
    ----------
    var_tree : int array (2 ** (d - 1),)
        The decision axes of the tree.
    split_tree : int array (2 ** (d - 1),)
        The decision boundaries of the tree.
    X : array (p, n)
        The points; `grove.traverse_tree` is vmapped over axis 1, so each
        column is one point.

    Returns
    -------
    distr : unsigned int array (n + 1,)
        ``distr[k]`` is the number of actual leaves containing exactly ``k``
        points.
    """
    traverse_tree = jax.vmap(grove.traverse_tree, in_axes=(1, None, None))
    indices = traverse_tree(X, var_tree, split_tree)

    # a single leaf can receive at most `indices.size` points
    count_tree = jnp.zeros(
        2 * split_tree.size, dtype=jaxext.minimal_unsigned_dtype(indices.size)
    )
    count_tree = count_tree.at[indices].add(1)

    is_leaf = grove.is_actual_leaf(split_tree, add_bottom_level=True)
    # jnp.bincount accumulates in the dtype of the weights. A uint8 view of
    # the mask would wrap around as soon as more than 255 leaves share the
    # same count (possible with deep trees), so use weights wide enough to
    # count every node of the tree.
    weights = is_leaf.astype(jaxext.minimal_unsigned_dtype(is_leaf.size))
    return jnp.bincount(count_tree, weights, length=X.shape[1] + 1)
91
95
 
96
+
92
97
def forest_points_per_leaf_distr(bart, X):
    """
    Sum `points_per_leaf_distr` over all the trees of a forest.

    Parameters
    ----------
    bart : mapping
        Must provide 'var_trees' and 'split_trees', stacked along a leading
        axis that is scanned over (one entry per tree).
    X : array (p, n)
        The points, one per column.

    Returns
    -------
    distr : int array (n + 1,)
        ``distr[k]`` is the number of leaves in the forest with ``k`` points.
    """

    def accumulate(total, tree_arrays):
        var_tree, split_tree = tree_arrays
        return total + points_per_leaf_distr(var_tree, split_tree, X), None

    initial = jnp.zeros(X.shape[1] + 1, int)
    forest = (bart['var_trees'], bart['split_trees'])
    total, _ = lax.scan(accumulate, initial, forest)
    return total
99
106
 
107
+
100
108
def trace_points_per_leaf_distr(bart, X):
    """
    Apply `forest_points_per_leaf_distr` to every forest in a stack.

    Parameters
    ----------
    bart : pytree
        Its leaves carry a leading axis that is scanned over; each slice must
        provide 'var_trees' and 'split_trees'.
    X : array (p, n)
        The points, one per column.

    Returns
    -------
    distr : int array (num_samples, n + 1)
        One histogram per slice of `bart`.
    """

    def per_sample(carry, sample):
        return carry, forest_points_per_leaf_distr(sample, X)

    _, stacked = lax.scan(per_sample, None, bart)
    return stacked
105
114
 
115
+
106
116
def check_types(leaf_tree, var_tree, split_tree, max_split):
    """
    Check that the tree arrays have the dtypes implied by `max_split`.

    `var_tree` must use the smallest unsigned dtype able to index the
    ``max_split.size`` variables. `split_tree` must share the dtype of
    `max_split`.
    """
    num_vars = max_split.size
    var_ok = var_tree.dtype == jaxext.minimal_unsigned_dtype(num_vars - 1)
    split_ok = split_tree.dtype == max_split.dtype
    return var_ok and split_ok
123
+
110
124
 
111
125
def check_sizes(leaf_tree, var_tree, split_tree, max_split):
    """Check that `leaf_tree` is twice as long as both `var_tree` and `split_tree`."""
    expected_leaf_size = 2 * var_tree.size
    return leaf_tree.size == expected_leaf_size == 2 * split_tree.size
113
127
 
128
+
114
129
def check_unused_node(leaf_tree, var_tree, split_tree, max_split):
    """Check that slot 0 (never a tree node; the root is at index 1) is zeroed."""
    var_is_zero = var_tree[0] == 0
    split_is_zero = split_tree[0] == 0
    return var_is_zero & split_is_zero
116
131
 
132
+
117
133
def check_leaf_values(leaf_tree, var_tree, split_tree, max_split):
    """Check that every leaf value is finite (no nan, no inf)."""
    finite_mask = jnp.isfinite(leaf_tree)
    return finite_mask.all()
119
135
 
136
+
120
137
  def check_stray_nodes(leaf_tree, var_tree, split_tree, max_split):
121
- index = jnp.arange(2 * split_tree.size, dtype=jaxext.minimal_unsigned_dtype(2 * split_tree.size - 1))
138
+ index = jnp.arange(
139
+ 2 * split_tree.size,
140
+ dtype=jaxext.minimal_unsigned_dtype(2 * split_tree.size - 1),
141
+ )
122
142
  parent_index = index >> 1
123
143
  is_not_leaf = split_tree.at[index].get(mode='fill', fill_value=0) != 0
124
144
  parent_is_leaf = split_tree[parent_index] == 0
@@ -126,6 +146,7 @@ def check_stray_nodes(leaf_tree, var_tree, split_tree, max_split):
126
146
  stray = stray.at[1].set(False)
127
147
  return ~jnp.any(stray)
128
148
 
149
+
129
150
  check_functions = [
130
151
  check_types,
131
152
  check_sizes,
@@ -134,6 +155,7 @@ check_functions = [
134
155
  check_stray_nodes,
135
156
  ]
136
157
 
158
+
137
159
  def check_tree(leaf_tree, var_tree, split_tree, max_split):
138
160
  error_type = jaxext.minimal_unsigned_dtype(2 ** len(check_functions) - 1)
139
161
  error = error_type(0)
@@ -144,15 +166,19 @@ def check_tree(leaf_tree, var_tree, split_tree, max_split):
144
166
  error |= bit
145
167
  return error
146
168
 
169
+
147
170
def describe_error(error):
    """
    Decode an error bitmask into the names of the failed checks.

    Parameters
    ----------
    error : int
        Bit ``i`` is set iff ``check_functions[i]`` failed.

    Returns
    -------
    names : list of str
        The ``__name__`` of each failed check, in `check_functions` order.
    """
    names = []
    for bit, check in enumerate(check_functions):
        if error & (1 << bit):
            names.append(check.__name__)
    return names
172
+
153
173
 
154
174
# Vectorized `check_tree`: maps over the leading axis of the leaf, var and
# split arrays (one entry per tree) while sharing a single `max_split`.
check_forest = jax.vmap(check_tree, in_axes=(0, 0, 0, None))
155
175
 
176
+
156
177
@functools.partial(jax.vmap, in_axes=(0, None))
def check_trace(trace, state):
    """
    Run `check_forest` on every entry along the leading axis of `trace`.

    Parameters
    ----------
    trace : mapping
        Provides 'leaf_trees', 'var_trees' and 'split_trees', each with a
        leading axis that is vmapped over.
    state : object
        Provides the `max_split` attribute; not vmapped.

    Returns
    -------
    error : array
        The `check_forest` error codes, with the leading axis of `trace`
        prepended.
    """
    leaf_trees, var_trees, split_trees = (
        trace[key] for key in ('leaf_trees', 'var_trees', 'split_trees')
    )
    return check_forest(leaf_trees, var_trees, split_trees, state.max_split)
bartz/grove.py CHANGED
@@ -1,6 +1,6 @@
1
1
  # bartz/src/bartz/grove.py
2
2
  #
3
- # Copyright (c) 2024, Giacomo Petrillo
3
+ # Copyright (c) 2024-2025, Giacomo Petrillo
4
4
  #
5
5
  # This file is part of bartz.
6
6
  #
@@ -34,7 +34,7 @@ The 'leaf' array contains the values in the leaves.
34
34
 
35
35
  The 'var' array contains the axes along which the decision nodes operate.
36
36
 
37
- The 'split' array contains the decision boundaries. The boundaries are open on the right, i.e., a point belongs to the left child iff x < split. Whether a node is a leaf is indicated by the corresponding 'split' element being 0.
37
+ The 'split' array contains the decision boundaries. The boundaries are open on the right, i.e., a point belongs to the left child iff x < split. Whether a node is a leaf is indicated by the corresponding 'split' element being 0. Unused nodes also have split set to 0.
38
38
 
39
39
  Since the nodes at the bottom can only be leaves and not decision nodes, the 'var' and 'split' arrays have half the length of the 'leaf' array.
40
40
 
@@ -44,11 +44,12 @@ import functools
44
44
  import math
45
45
 
46
46
  import jax
47
- from jax import numpy as jnp
48
47
  from jax import lax
48
+ from jax import numpy as jnp
49
49
 
50
50
  from . import jaxext
51
51
 
52
+
52
53
  def make_tree(depth, dtype):
53
54
  """
54
55
  Make an array to represent a binary tree.
@@ -66,7 +67,8 @@ def make_tree(depth, dtype):
66
67
  tree : array
67
68
  An array of zeroes with shape (2 ** depth,).
68
69
  """
69
- return jnp.zeros(2 ** depth, dtype)
70
+ return jnp.zeros(2**depth, dtype)
71
+
70
72
 
71
73
  def tree_depth(tree):
72
74
  """
@@ -85,6 +87,7 @@ def tree_depth(tree):
85
87
  """
86
88
  return int(round(math.log2(tree.shape[-1])))
87
89
 
90
+
88
91
  def traverse_tree(x, var_tree, split_tree):
89
92
  """
90
93
  Find the leaf where a point falls into.
@@ -103,7 +106,6 @@ def traverse_tree(x, var_tree, split_tree):
103
106
  index : int
104
107
  The index of the leaf.
105
108
  """
106
-
107
109
  carry = (
108
110
  jnp.zeros((), bool),
109
111
  jnp.ones((), jaxext.minimal_unsigned_dtype(2 * var_tree.size - 1)),
@@ -125,6 +127,7 @@ def traverse_tree(x, var_tree, split_tree):
125
127
  (_, index), _ = lax.scan(loop, carry, None, depth, unroll=16)
126
128
  return index
127
129
 
130
+
128
131
  @functools.partial(jaxext.vmap_nodoc, in_axes=(None, 0, 0))
129
132
  @functools.partial(jaxext.vmap_nodoc, in_axes=(1, None, None))
130
133
  def traverse_forest(X, var_trees, split_trees):
@@ -147,6 +150,7 @@ def traverse_forest(X, var_trees, split_trees):
147
150
  """
148
151
  return traverse_tree(X, var_trees, split_trees)
149
152
 
153
+
150
154
  def evaluate_forest(X, leaf_trees, var_trees, split_trees, dtype=None, sum_trees=True):
151
155
  """
152
156
  Evaluate a ensemble of trees at an array of points.
@@ -178,11 +182,12 @@ def evaluate_forest(X, leaf_trees, var_trees, split_trees, dtype=None, sum_trees
178
182
  leaves = leaf_trees[tree_index[:, None], indices]
179
183
  if sum_trees:
180
184
  return jnp.sum(leaves, axis=0, dtype=dtype)
181
- # this sum suggests to swap the vmaps, but I think it's better for X
182
- # copying to keep it that way
185
+ # this sum suggests to swap the vmaps, but I think it's better for X
186
+ # copying to keep it that way
183
187
  else:
184
188
  return leaves
185
189
 
190
+
186
191
  def is_actual_leaf(split_tree, *, add_bottom_level=False):
187
192
  """
188
193
  Return a mask indicating the leaf nodes in a tree.
@@ -211,6 +216,7 @@ def is_actual_leaf(split_tree, *, add_bottom_level=False):
211
216
  parent_nonleaf = parent_nonleaf.at[1].set(True)
212
217
  return is_leaf & parent_nonleaf
213
218
 
219
+
214
220
  def is_leaves_parent(split_tree):
215
221
  """
216
222
  Return a mask indicating the nodes with leaf (and only leaf) children.
@@ -225,14 +231,17 @@ def is_leaves_parent(split_tree):
225
231
  is_leaves_parent : bool array (2 ** (d - 1),)
226
232
  The mask indicating which nodes have leaf children.
227
233
  """
228
- index = jnp.arange(split_tree.size, dtype=jaxext.minimal_unsigned_dtype(2 * split_tree.size - 1))
229
- left_index = index << 1 # left child
230
- right_index = left_index + 1 # right child
234
+ index = jnp.arange(
235
+ split_tree.size, dtype=jaxext.minimal_unsigned_dtype(2 * split_tree.size - 1)
236
+ )
237
+ left_index = index << 1 # left child
238
+ right_index = left_index + 1 # right child
231
239
  left_leaf = split_tree.at[left_index].get(mode='fill', fill_value=0) == 0
232
240
  right_leaf = split_tree.at[right_index].get(mode='fill', fill_value=0) == 0
233
241
  is_not_leaf = split_tree.astype(bool)
234
242
  return is_not_leaf & left_leaf & right_leaf
235
- # the 0-th item has split == 0, so it's not counted
243
+ # the 0-th item has split == 0, so it's not counted
244
+
236
245
 
237
246
  def tree_depths(tree_length):
238
247
  """
@@ -253,8 +262,49 @@ def tree_depths(tree_length):
253
262
  depths = []
254
263
  depth = 0
255
264
  for i in range(tree_length):
256
- if i == 2 ** depth:
265
+ if i == 2**depth:
257
266
  depth += 1
258
267
  depths.append(depth - 1)
259
268
  depths[0] = 0
260
269
  return jnp.array(depths, jaxext.minimal_unsigned_dtype(max(depths)))
270
+
271
+
272
def is_used(split_tree):
    """
    Return a mask indicating the used nodes in a tree.

    A node is used if it is a decision node (nonzero split) or an actual
    leaf.

    Parameters
    ----------
    split_tree : int array (2 ** (d - 1),)
        The decision boundaries of the tree.

    Returns
    -------
    is_used : bool array (2 ** d,)
        A mask indicating which nodes are actually used.
    """
    leaf_mask = is_actual_leaf(split_tree, add_bottom_level=True)
    has_split = split_tree != 0
    # the bottom level can hold only leaves, so no decision nodes there
    decision_mask = jnp.concatenate([has_split, jnp.zeros_like(has_split)])
    return decision_mask | leaf_mask
290
+
291
+
292
def forest_fill(split_trees):
    """
    Return the fraction of used nodes in a set of trees.

    Parameters
    ----------
    split_trees : array (m, 2 ** (d - 1),)
        The decision boundaries of the trees.

    Returns
    -------
    fill : float
        The number of used nodes across all trees, divided by the total
        number of node slots available in the arrays.
    """
    num_trees, _ = split_trees.shape
    used_mask = jax.vmap(is_used)(split_trees)
    # slot 0 of each tree array is not a tree node (the root is at index 1),
    # so it is excluded from the capacity
    capacity = used_mask.size - num_trees
    return jnp.count_nonzero(used_mask) / capacity
bartz/jaxext.py CHANGED
@@ -22,60 +22,74 @@
22
22
  # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23
23
  # SOFTWARE.
24
24
 
25
+ """Additions to jax."""
26
+
25
27
  import functools
26
28
  import math
27
29
  import warnings
28
30
 
29
- from scipy import special
30
31
  import jax
32
+ from jax import lax, random, tree_util
31
33
  from jax import numpy as jnp
32
- from jax import tree_util
33
- from jax import lax
34
+ from scipy import special
35
+
34
36
 
35
37
def float_type(*args):
    """
    Determine the jax floating point result type given operands/types.

    The dtype is found by applying a floating point function to an empty
    array of the promoted input type.
    """
    probe = jnp.empty(0, dtype=jnp.result_type(*args))
    return jnp.sin(probe).dtype
41
41
 
42
- def castto(func, type):
42
+
43
+ def _castto(func, type):
43
44
  @functools.wraps(func)
44
45
  def newfunc(*args, **kw):
45
46
  return func(*args, **kw).astype(type)
47
+
46
48
  return newfunc
47
49
 
50
+
48
51
class scipy:
    """Mockup of the :external:py:mod:`scipy` module."""

    class special:
        """Mockup of the :external:py:mod:`scipy.special` module."""

        @staticmethod
        def gammainccinv(a, y):
            """Survival function inverse of the Gamma(a, 1) distribution."""
            a, y = jnp.asarray(a), jnp.asarray(y)
            out_dtype = float_type(a.dtype, y.dtype)
            result_spec = jax.ShapeDtypeStruct(
                jnp.broadcast_shapes(a.shape, y.shape), out_dtype
            )
            # `special` here resolves to the real `scipy.special` imported at
            # module level: class bodies are not enclosing scopes, so the
            # mock class of the same name is not visible from this function.
            host_func = _castto(special.gammainccinv, out_dtype)
            return jax.pure_callback(
                host_func, result_spec, a, y, vmap_method='expand_dims'
            )

    class stats:
        """Mockup of the :external:py:mod:`scipy.stats` module."""

        class invgamma:
            """Class that represents the distribution InvGamma(a, 1)."""

            @staticmethod
            def ppf(q, a):
                """Percentile point function."""
                # if X ~ Gamma(a, 1), then 1 / X ~ InvGamma(a, 1) and the
                # q-quantile of 1 / X is 1 over the upper q-quantile of X
                gamma_upper_quantile = scipy.special.gammainccinv(a, q)
                return 1 / gamma_upper_quantile
68
78
 
69
def vmap_nodoc(fun, *args, **kw):
    """
    Acts like `jax.vmap` but preserves the docstring of the function unchanged.

    This is useful if the docstring already takes into account that the
    arguments have additional axes due to vmap.
    """
    batched = jax.vmap(fun, *args, **kw)
    # overwrite the docstring that jax.vmap generates with the original one
    batched.__doc__ = fun.__doc__
    return batched
78
91
 
92
+
79
93
  def huge_value(x):
80
94
  """
81
95
  Return the maximum value that can be stored in `x`.
@@ -95,23 +109,23 @@ def huge_value(x):
95
109
  else:
96
110
  return jnp.inf
97
111
 
98
def minimal_unsigned_dtype(value):
    """
    Return the smallest unsigned integer dtype that can represent `value`.

    Falls back to uint64 for anything that does not fit in 32 bits.
    """
    for nbits, candidate in ((8, jnp.uint8), (16, jnp.uint16), (32, jnp.uint32)):
        if value < 1 << nbits:
            return candidate
    return jnp.uint64
110
122
 
123
+
111
124
  def signed_to_unsigned(int_dtype):
112
125
  """
113
- Map a signed integer type to its unsigned counterpart. Unsigned types are
114
- passed through.
126
+ Map a signed integer type to its unsigned counterpart.
127
+
128
+ Unsigned types are passed through.
115
129
  """
116
130
  assert jnp.issubdtype(int_dtype, jnp.integer)
117
131
  if jnp.issubdtype(int_dtype, jnp.unsignedinteger):
@@ -125,12 +139,12 @@ def signed_to_unsigned(int_dtype):
125
139
  if int_dtype == jnp.int64:
126
140
  return jnp.uint64
127
141
 
142
+
128
143
def ensure_unsigned(x):
    """
    Cast an integer array to the unsigned dtype of the same width.

    The target dtype comes from `signed_to_unsigned`, which maps signed
    integer dtypes to their unsigned counterparts and passes unsigned ones
    through.
    """
    target_dtype = signed_to_unsigned(x.dtype)
    return x.astype(target_dtype)
133
146
 
147
+
134
148
  @functools.partial(jax.jit, static_argnums=(1,))
135
149
  def unique(x, size, fill_value):
136
150
  """
@@ -158,15 +172,18 @@ def unique(x, size, fill_value):
158
172
  if size == 0:
159
173
  return jnp.empty(0, x.dtype), 0
160
174
  x = jnp.sort(x)
175
+
161
176
  def loop(carry, x):
162
177
  i_out, i_in, last, out = carry
163
178
  i_out = jnp.where(x == last, i_out, i_out + 1)
164
179
  out = out.at[i_out].set(x)
165
180
  return (i_out, i_in + 1, x, out), None
181
+
166
182
  carry = 0, 0, x[0], jnp.full(size, fill_value, x.dtype)
167
183
  (actual_length, _, _, out), _ = jax.lax.scan(loop, carry, x[:size])
168
184
  return out, actual_length + 1
169
185
 
186
+
170
187
  def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False):
171
188
  """
172
189
  Batch a function such that each batch is smaller than a threshold.
@@ -203,6 +220,7 @@ def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
203
220
  def check_no_nones(axes, tree):
204
221
  def check_not_none(_, axis):
205
222
  assert axis is not None
223
+
206
224
  tree_util.tree_map(check_not_none, tree, axes)
207
225
 
208
226
  def extract_size(axes, tree):
@@ -211,6 +229,7 @@ def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
211
229
  return None
212
230
  else:
213
231
  return x.shape[axis]
232
+
214
233
  sizes = tree_util.tree_map(get_size, tree, axes)
215
234
  sizes, _ = tree_util.tree_flatten(sizes)
216
235
  assert all(s == sizes[0] for s in sizes)
@@ -219,6 +238,7 @@ def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
219
238
  def sum_nbytes(tree):
220
239
  def nbytes(x):
221
240
  return math.prod(x.shape) * x.dtype.itemsize
241
+
222
242
  return tree_util.tree_reduce(lambda size, x: size + nbytes(x), tree, 0)
223
243
 
224
244
  def next_divisor_small(dividend, min_divisor):
@@ -247,6 +267,7 @@ def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
247
267
  return None
248
268
  else:
249
269
  return x
270
+
250
271
  return tree_util.tree_map(pull_nonbatched, tree, axes), tree
251
272
 
252
273
  def push_nonbatched(axes, tree, original_tree):
@@ -255,32 +276,38 @@ def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
255
276
  return original_x
256
277
  else:
257
278
  return x
279
+
258
280
  return tree_util.tree_map(push_nonbatched, original_tree, tree, axes)
259
281
 
260
282
  def move_axes_out(axes, tree):
261
283
  def move_axis_out(x, axis):
262
284
  return jnp.moveaxis(x, axis, 0)
285
+
263
286
  return tree_util.tree_map(move_axis_out, tree, axes)
264
287
 
265
288
  def move_axes_in(axes, tree):
266
289
  def move_axis_in(x, axis):
267
290
  return jnp.moveaxis(x, 0, axis)
291
+
268
292
  return tree_util.tree_map(move_axis_in, tree, axes)
269
293
 
270
294
  def batch(tree, nbatches):
271
295
  def batch(x):
272
296
  return x.reshape((nbatches, x.shape[0] // nbatches) + x.shape[1:])
297
+
273
298
  return tree_util.tree_map(batch, tree)
274
299
 
275
300
  def unbatch(tree):
276
301
  def unbatch(x):
277
302
  return x.reshape((x.shape[0] * x.shape[1],) + x.shape[2:])
303
+
278
304
  return tree_util.tree_map(unbatch, tree)
279
305
 
280
306
  def check_same(tree1, tree2):
281
307
  def check_same(x1, x2):
282
308
  assert x1.shape == x2.shape
283
309
  assert x1.dtype == x2.dtype
310
+
284
311
  tree_util.tree_map(check_same, tree1, tree2)
285
312
 
286
313
  initial_in_axes = in_axes
@@ -300,7 +327,9 @@ def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
300
327
  args, nonbatched_args = pull_nonbatched(in_axes, args)
301
328
 
302
329
  total_nbytes = sum_nbytes((args, example_result))
303
- min_nbatches = total_nbytes // max_io_nbytes + bool(total_nbytes % max_io_nbytes)
330
+ min_nbatches = total_nbytes // max_io_nbytes + bool(
331
+ total_nbytes % max_io_nbytes
332
+ )
304
333
  min_nbatches = max(1, min_nbatches)
305
334
  nbatches = next_divisor(size, min_nbatches)
306
335
  assert 1 <= nbatches <= max(1, size)
@@ -310,7 +339,9 @@ def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
310
339
  batch_nbytes = total_nbytes // nbatches
311
340
  if batch_nbytes > max_io_nbytes:
312
341
  assert size == nbatches
313
- warnings.warn(f'batch_nbytes = {batch_nbytes} > max_io_nbytes = {max_io_nbytes}')
342
+ warnings.warn(
343
+ f'batch_nbytes = {batch_nbytes} > max_io_nbytes = {max_io_nbytes}'
344
+ )
314
345
 
315
346
  def loop(_, args):
316
347
  args = move_axes_in(in_axes, args)
@@ -333,17 +364,60 @@ def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
333
364
 
334
365
  return batched_func
335
366
 
336
- @tree_util.register_pytree_node_class
337
- class LeafDict(dict):
338
- """ dictionary that acts as a leaf in jax pytrees, to store compile-time
339
- values """
340
367
 
341
- def tree_flatten(self):
342
- return (), self
368
+ class split:
369
+ """
370
+ Split a key into `num` keys.
343
371
 
344
- @classmethod
345
- def tree_unflatten(cls, aux_data, children):
346
- return aux_data
372
+ Parameters
373
+ ----------
374
+ key : jax.dtypes.prng_key array
375
+ The key to split.
376
+ num : int
377
+ The number of keys to split into.
378
+ """
347
379
 
348
- def __repr__(self):
349
- return f'{__class__.__name__}({super().__repr__()})'
380
+ def __init__(self, key, num=2):
381
+ self._keys = random.split(key, num)
382
+
383
+ def __len__(self):
384
+ return self._keys.size
385
+
386
+ def pop(self, shape=None):
387
+ """
388
+ Pop one or more keys from the list.
389
+
390
+ Parameters
391
+ ----------
392
+ shape : int or tuple of int, optional
393
+ The shape of the keys to pop. If `None`, a single key is popped.
394
+ If an integer, that many keys are popped. If a tuple, the keys are
395
+ reshaped to that shape.
396
+
397
+ Returns
398
+ -------
399
+ keys : jax.dtypes.prng_key array
400
+ The popped keys.
401
+
402
+ Raises
403
+ ------
404
+ IndexError
405
+ If `shape` is larger than the number of keys left in the list.
406
+
407
+ Notes
408
+ -----
409
+ The keys are popped from the beginning of the list, so for example
410
+ ``list(keys.pop(2))`` is equivalent to ``[keys.pop(), keys.pop()]``.
411
+ """
412
+ if shape is None:
413
+ shape = ()
414
+ elif not isinstance(shape, tuple):
415
+ shape = (shape,)
416
+ size_to_pop = math.prod(shape)
417
+ if size_to_pop > self._keys.size:
418
+ raise IndexError(
419
+ f'Cannot pop {size_to_pop} keys from {self._keys.size} keys'
420
+ )
421
+ popped_keys = self._keys[:size_to_pop]
422
+ self._keys = self._keys[size_to_pop:]
423
+ return popped_keys.reshape(shape)