bartz 0.0.1__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bartz/jaxext.py CHANGED
@@ -10,10 +10,10 @@
10
10
  # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11
11
  # copies of the Software, and to permit persons to whom the Software is
12
12
  # furnished to do so, subject to the following conditions:
13
- #
13
+ #
14
14
  # The above copyright notice and this permission notice shall be included in all
15
15
  # copies or substantial portions of the Software.
16
- #
16
+ #
17
17
  # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18
18
  # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19
19
  # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
@@ -23,12 +23,19 @@
23
23
  # SOFTWARE.
24
24
 
25
25
  import functools
26
+ import math
27
+ import warnings
26
28
 
27
29
  from scipy import special
28
30
  import jax
29
31
  from jax import numpy as jnp
32
+ from jax import tree_util
33
+ from jax import lax
30
34
 
31
35
  def float_type(*args):
36
+ """
37
+ Determine the jax floating point result type given operands/types.
38
+ """
32
39
  t = jnp.result_type(*args)
33
40
  return jnp.sin(jnp.empty(0, t)).dtype
34
41
 
@@ -39,8 +46,8 @@ def castto(func, type):
39
46
  return newfunc
40
47
 
41
48
  def pure_callback_ufunc(callback, dtype, *args, excluded=None, **kwargs):
42
- """ version of jax.pure_callback that deals correctly with ufuncs,
43
- see https://github.com/google/jax/issues/17187 """
49
+ """ version of `jax.pure_callback` that deals correctly with ufuncs,
50
+ see `<https://github.com/google/jax/issues/17187>`_ """
44
51
  if excluded is None:
45
52
  excluded = ()
46
53
  shape = jnp.broadcast_shapes(*(
@@ -63,6 +70,7 @@ class scipy:
63
70
 
64
71
  class special:
65
72
 
73
+ @functools.wraps(special.gammainccinv)
66
74
  def gammainccinv(a, y):
67
75
  a = jnp.asarray(a)
68
76
  y = jnp.asarray(y)
@@ -73,13 +81,261 @@ class scipy:
73
81
  class stats:
74
82
 
75
83
  class invgamma:
76
-
84
+
77
85
  def ppf(q, a):
78
86
  return 1 / scipy.special.gammainccinv(a, q)
79
87
 
80
88
  @functools.wraps(jax.vmap)
81
89
  def vmap_nodoc(fun, *args, **kw):
90
+ """
91
+ Version of `jax.vmap` that preserves the docstring of the input function.
92
+ """
82
93
  doc = fun.__doc__
83
94
  fun = jax.vmap(fun, *args, **kw)
84
95
  fun.__doc__ = doc
85
96
  return fun
97
+
98
def huge_value(x):
    """
    Return the maximum value that can be stored in `x`.

    Parameters
    ----------
    x : array
        A numerical numpy or jax array.

    Returns
    -------
    maxval : scalar
        The largest integer representable by `x`'s dtype if it is an integer
        type, +inf otherwise.
    """
    dtype = x.dtype
    is_integer = jnp.issubdtype(dtype, jnp.integer)
    return jnp.iinfo(dtype).max if is_integer else jnp.inf
116
+
117
def minimal_unsigned_dtype(max_value):
    """
    Return the smallest unsigned integer dtype that can represent a given
    maximum value (inclusive).

    Candidates are tried from narrowest to widest; `jnp.uint64` is returned
    when `max_value` does not fit in 32 bits.
    """
    candidates = (
        (8, jnp.uint8),
        (16, jnp.uint16),
        (32, jnp.uint32),
    )
    for nbits, dtype in candidates:
        if max_value < 2 ** nbits:
            return dtype
    return jnp.uint64
129
+
130
def signed_to_unsigned(int_dtype):
    """
    Map a signed integer type to its unsigned counterpart. Unsigned types are
    passed through.

    Parameters
    ----------
    int_dtype : dtype
        An integer type (signed or unsigned).

    Returns
    -------
    unsigned_dtype : dtype
        `int_dtype` itself if it is unsigned, otherwise the unsigned type
        with the same number of bits.

    Raises
    ------
    AssertionError
        If `int_dtype` is not an integer type.
    TypeError
        If `int_dtype` is a signed integer type with no known unsigned
        counterpart (anything other than 8, 16, 32 or 64 bits).
    """
    assert jnp.issubdtype(int_dtype, jnp.integer)
    if jnp.issubdtype(int_dtype, jnp.unsignedinteger):
        return int_dtype
    counterparts = (
        (jnp.int8, jnp.uint8),
        (jnp.int16, jnp.uint16),
        (jnp.int32, jnp.uint32),
        (jnp.int64, jnp.uint64),
    )
    for signed, unsigned in counterparts:
        if int_dtype == signed:
            return unsigned
    # fail loudly instead of implicitly returning None, which callers would
    # otherwise forward to `astype`
    raise TypeError(f'no unsigned counterpart known for integer type {int_dtype!r}')
146
+
147
def ensure_unsigned(x):
    """
    If x has signed integer type, cast it to the unsigned dtype of the same size.

    The target type is determined by `signed_to_unsigned`, so unsigned
    inputs keep their dtype.
    """
    target_dtype = signed_to_unsigned(x.dtype)
    return x.astype(target_dtype)
152
+
153
@functools.partial(jax.jit, static_argnums=(1,))
def unique(x, size, fill_value):
    """
    Restricted version of `jax.numpy.unique` that uses less memory.

    Parameters
    ----------
    x : 1d array
        The input array.
    size : int
        The length of the output.
    fill_value : scalar
        The value to fill the output with if `size` is greater than the number
        of unique values in `x`.

    Returns
    -------
    out : array (size,)
        The unique values in `x`, sorted, and right-padded with `fill_value`.
        If `x` has more than `size` unique values, only the `size` smallest
        ones are returned.
    actual_length : int
        The number of used values in `out` (at most `size`).
    """
    if x.size == 0:
        return jnp.full(size, fill_value, x.dtype), 0
    if size == 0:
        return jnp.empty(0, x.dtype), 0
    x = jnp.sort(x)

    def loop(carry, value):
        i_out, last, out = carry
        # advance the output slot only when a new value is encountered
        i_out = jnp.where(value == last, i_out, i_out + 1)
        # once more than `size` unique values have been seen, i_out falls
        # outside `out`; such writes are explicitly dropped
        out = out.at[i_out].set(value, mode='drop')
        return (i_out, value, out), None

    # the first iteration compares x[0] with itself and writes it to slot 0
    carry = 0, x[0], jnp.full(size, fill_value, x.dtype)
    # scan the whole sorted array: truncating it to the first `size` elements
    # would miss unique values preceded by many duplicates
    (last_index, _, out), _ = jax.lax.scan(loop, carry, x)
    return out, jnp.minimum(last_index + 1, size)
188
+
189
def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False):
    """
    Batch a function such that each batch is smaller than a threshold.

    Parameters
    ----------
    func : callable
        A jittable function with positional arguments only, with inputs and
        outputs pytrees of arrays.
    max_io_nbytes : int
        The maximum number of input + output bytes in each batch.
    in_axes : pytree of ints, default 0
        A tree matching the structure of the function input, indicating along
        which axes each array should be batched. If a single integer, it is
        used for all arrays.
    out_axes : pytree of ints, default 0
        The same for outputs.
    return_nbatches : bool, default False
        If True, the number of batches is returned as a second output.

    Returns
    -------
    batched_func : callable
        A function with the same signature as `func`, but that processes the
        input and output in batches in a loop.

    Notes
    -----
    The number of batches is the smallest divisor of the length of the batched
    axis that makes each batch fit in `max_io_nbytes`. If even one-element
    batches are too large, one-element batches are used and a warning is
    emitted.
    """

    def expand_axes(axes, tree):
        # broadcast a single integer to every leaf of the tree
        if isinstance(axes, int):
            return tree_util.tree_map(lambda _: axes, tree)
        return tree_util.tree_map(lambda _, axis: axis, tree, axes)

    def extract_size(axes, tree):
        # length of the batched axis, which must be the same for all leaves
        sizes = tree_util.tree_map(lambda x, axis: x.shape[axis], tree, axes)
        sizes, _ = tree_util.tree_flatten(sizes)
        assert all(s == sizes[0] for s in sizes)
        return sizes[0]

    def sum_nbytes(tree):
        def nbytes(x):
            return math.prod(x.shape) * x.dtype.itemsize
        return tree_util.tree_reduce(lambda size, x: size + nbytes(x), tree, 0)

    def next_divisor(dividend, min_divisor):
        # smallest divisor of `dividend` that is >= `min_divisor`, or
        # `dividend` itself if there is none
        root = math.isqrt(dividend)
        for divisor in range(min_divisor, root + 1):
            if dividend % divisor == 0:
                return divisor
        # no divisor up to the square root: any remaining candidate is
        # dividend // inv_divisor with inv_divisor small, so scan the
        # cofactors from the largest admissible one downwards
        lower_bound = max(min_divisor, root + 1)
        for inv_divisor in range(dividend // lower_bound, 0, -1):
            if dividend % inv_divisor == 0:
                return dividend // inv_divisor
        return dividend

    def move_axes_out(axes, tree):
        # bring the batched axis of each leaf to the front
        def move_axis_out(axis, x):
            if axis != 0:
                return jnp.moveaxis(x, axis, 0)
            return x
        return tree_util.tree_map(move_axis_out, axes, tree)

    def move_axes_in(axes, tree):
        # put the leading axis of each leaf back to its original position
        def move_axis_in(axis, x):
            if axis != 0:
                return jnp.moveaxis(x, 0, axis)
            return x
        return tree_util.tree_map(move_axis_in, axes, tree)

    def batch(tree, nbatches):
        # split the leading axis into (nbatches, batch_size)
        def batch(x):
            return x.reshape((nbatches, x.shape[0] // nbatches) + x.shape[1:])
        return tree_util.tree_map(batch, tree)

    def unbatch(tree):
        # merge the two leading axes back into one
        def unbatch(x):
            return x.reshape((x.shape[0] * x.shape[1],) + x.shape[2:])
        return tree_util.tree_map(unbatch, tree)

    def check_same(tree1, tree2):
        def check_same(x1, x2):
            assert x1.shape == x2.shape
            assert x1.dtype == x2.dtype
        tree_util.tree_map(check_same, tree1, tree2)

    initial_in_axes = in_axes
    initial_out_axes = out_axes

    @jax.jit
    @functools.wraps(func)
    def batched_func(*args):
        # shapes and dtypes of the output, without running the function
        example_result = jax.eval_shape(func, *args)

        in_axes = expand_axes(initial_in_axes, args)
        out_axes = expand_axes(initial_out_axes, example_result)

        in_size = extract_size(in_axes, args)
        out_size = extract_size(out_axes, example_result)
        assert in_size == out_size
        size = in_size

        total_nbytes = sum_nbytes(args) + sum_nbytes(example_result)
        # ceil division, with at least one batch even if there are no bytes
        # to process (a zero minimum would lead to a division by zero)
        min_nbatches = total_nbytes // max_io_nbytes + bool(total_nbytes % max_io_nbytes)
        min_nbatches = max(1, min_nbatches)
        nbatches = next_divisor(size, min_nbatches)
        assert 1 <= nbatches <= size
        assert size % nbatches == 0
        assert total_nbytes % nbatches == 0

        batch_nbytes = total_nbytes // nbatches
        if batch_nbytes > max_io_nbytes:
            # the threshold can be exceeded only with one-element batches
            assert size == nbatches
            warnings.warn(f'batch_nbytes = {batch_nbytes} > max_io_nbytes = {max_io_nbytes}')

        def loop(_, args):
            args = move_axes_in(in_axes, args)
            result = func(*args)
            result = move_axes_out(out_axes, result)
            return None, result

        args = move_axes_out(in_axes, args)
        args = batch(args, nbatches)
        _, result = lax.scan(loop, None, args)
        result = unbatch(result)
        result = move_axes_in(out_axes, result)

        check_same(example_result, result)

        if return_nbatches:
            return result, nbatches
        return result

    return batched_func
327
+
328
@tree_util.register_pytree_node_class
class LeafDict(dict):
    """
    Dictionary that acts as a leaf in jax pytrees, to store compile-time
    values.

    The dictionary is flattened to no children at all, and carried whole as
    auxiliary data of the tree structure.
    """

    def tree_flatten(self):
        # no children: the dictionary itself is the auxiliary data
        children = ()
        return children, self

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # the auxiliary data is the original dictionary
        return aux_data

    def __repr__(self):
        contents = super().__repr__()
        return f'{__class__.__name__}({contents})'
bartz/mcmcloop.py CHANGED
@@ -52,7 +52,7 @@ def run_mcmc(bart, n_burn, n_save, n_skip, callback, key):
52
52
  n_save : int
53
53
  The number of iterations to save.
54
54
  n_skip : int
55
- The number of iterations to skip between each saved iteration.
55
+ The number of iterations to skip between each saved iteration, plus 1.
56
56
  callback : callable
57
57
  An arbitrary function run at each iteration, called with the following
58
58
  arguments, passed by keyword:
@@ -100,15 +100,24 @@ def run_mcmc(bart, n_burn, n_save, n_skip, callback, key):
100
100
  def inner_loop(carry, _, tracelist, burnin):
101
101
  bart, i_total, i_skip, key = carry
102
102
  key, subkey = random.split(key)
103
- bart = mcmcstep.mcmc_step(bart, subkey)
103
+ bart = mcmcstep.step(bart, subkey)
104
104
  callback(bart=bart, burnin=burnin, i_total=i_total, i_skip=i_skip, **callback_kw)
105
105
  output = {key: bart[key] for key in tracelist}
106
106
  return (bart, i_total + 1, i_skip + 1, key), output
107
107
 
108
- # TODO avoid invoking this altogether if burnin is 0 to shorten compilation time & size
109
- carry = bart, 0, 0, key
110
- burnin_loop = functools.partial(inner_loop, tracelist=tracelist_burnin, burnin=True)
111
- (bart, i_total, _, key), burnin_trace = lax.scan(burnin_loop, carry, None, n_burn)
108
+ def empty_trace(bart, tracelist):
109
+ return {
110
+ key: jnp.empty((0,) + bart[key].shape, bart[key].dtype)
111
+ for key in tracelist
112
+ }
113
+
114
+ if n_burn > 0:
115
+ carry = bart, 0, 0, key
116
+ burnin_loop = functools.partial(inner_loop, tracelist=tracelist_burnin, burnin=True)
117
+ (bart, i_total, _, key), burnin_trace = lax.scan(burnin_loop, carry, None, n_burn)
118
+ else:
119
+ i_total = 0
120
+ burnin_trace = empty_trace(bart, tracelist_burnin)
112
121
 
113
122
  def outer_loop(carry, _):
114
123
  bart, i_total, key = carry
@@ -118,8 +127,11 @@ def run_mcmc(bart, n_burn, n_save, n_skip, callback, key):
118
127
  output = {key: bart[key] for key in tracelist_main}
119
128
  return (bart, i_total, key), output
120
129
 
121
- carry = bart, i_total, key
122
- (bart, _, _), main_trace = lax.scan(outer_loop, carry, None, n_save)
130
+ if n_save > 0:
131
+ carry = bart, i_total, key
132
+ (bart, _, _), main_trace = lax.scan(outer_loop, carry, None, n_save)
133
+ else:
134
+ main_trace = empty_trace(bart, tracelist_main)
123
135
 
124
136
  return bart, burnin_trace, main_trace
125
137
 
@@ -127,7 +139,8 @@ def run_mcmc(bart, n_burn, n_save, n_skip, callback, key):
127
139
 
128
140
  @functools.lru_cache
129
141
  # cache to make the callback function object unique, such that the jit
130
- # of run_mcmc recognizes it
142
+ # of run_mcmc recognizes it => with the callback state, I can make
143
+ # printevery a runtime quantity
131
144
  def make_simple_print_callback(printevery):
132
145
  """
133
146
  Create a logging callback function for MCMC iterations.
@@ -149,11 +162,12 @@ def make_simple_print_callback(printevery):
149
162
  grow_acc = bart['grow_acc_count'] / bart['grow_prop_count']
150
163
  prune_acc = bart['prune_acc_count'] / bart['prune_prop_count']
151
164
  n_total = n_burn + n_save * n_skip
152
- debug.callback(simple_print_callback_impl, burnin, i_total, n_total, grow_prop, grow_acc, prune_prop, prune_acc, printevery)
165
+ printcond = (i_total + 1) % printevery == 0
166
+ debug.callback(_simple_print_callback, burnin, i_total, n_total, grow_prop, grow_acc, prune_prop, prune_acc, printcond)
153
167
  return callback
154
168
 
155
- def simple_print_callback_impl(burnin, i_total, n_total, grow_prop, grow_acc, prune_prop, prune_acc, printevery):
156
- if (i_total + 1) % printevery == 0:
169
+ def _simple_print_callback(burnin, i_total, n_total, grow_prop, grow_acc, prune_prop, prune_acc, printcond):
170
+ if printcond:
157
171
  burnin_flag = ' (burnin)' if burnin else ''
158
172
  total_str = str(n_total)
159
173
  ndigits = len(total_str)
@@ -180,6 +194,6 @@ def evaluate_trace(trace, X):
180
194
  The predictions for each iteration of the MCMC.
181
195
  """
182
196
  def loop(_, state):
183
- return None, grove.evaluate_tree_vmap_x(X, state['leaf_trees'], state['var_trees'], state['split_trees'], jnp.float32)
197
+ return None, grove.evaluate_forest(X, state['leaf_trees'], state['var_trees'], state['split_trees'], jnp.float32)
184
198
  _, y = lax.scan(loop, None, trace)
185
199
  return y