bartz 0.0.1__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bartz/{interface.py → BART.py} +10 -18
- bartz/__init__.py +7 -2
- bartz/_version.py +1 -0
- bartz/debug.py +9 -22
- bartz/grove.py +73 -120
- bartz/jaxext.py +261 -5
- bartz/mcmcloop.py +27 -13
- bartz/mcmcstep.py +510 -439
- bartz/prepcovars.py +25 -30
- {bartz-0.0.1.dist-info → bartz-0.2.0.dist-info}/METADATA +7 -1
- bartz-0.2.0.dist-info/RECORD +13 -0
- bartz-0.0.1.dist-info/RECORD +0 -12
- {bartz-0.0.1.dist-info → bartz-0.2.0.dist-info}/LICENSE +0 -0
- {bartz-0.0.1.dist-info → bartz-0.2.0.dist-info}/WHEEL +0 -0
bartz/jaxext.py
CHANGED
@@ -10,10 +10,10 @@
 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 # copies of the Software, and to permit persons to whom the Software is
 # furnished to do so, subject to the following conditions:
-#
+#
 # The above copyright notice and this permission notice shall be included in all
 # copies or substantial portions of the Software.
-#
+#
 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
@@ -23,12 +23,19 @@
 # SOFTWARE.

 import functools
+import math
+import warnings

 from scipy import special
 import jax
 from jax import numpy as jnp
+from jax import tree_util
+from jax import lax

 def float_type(*args):
+    """
+    Determine the jax floating point result type given operands/types.
+    """
     t = jnp.result_type(*args)
     return jnp.sin(jnp.empty(0, t)).dtype

@@ -39,8 +46,8 @@ def castto(func, type):
     return newfunc

 def pure_callback_ufunc(callback, dtype, *args, excluded=None, **kwargs):
-    """ version of jax.pure_callback that deals correctly with ufuncs,
-    see https://github.com/google/jax/issues/17187 """
+    """ version of `jax.pure_callback` that deals correctly with ufuncs,
+    see `<https://github.com/google/jax/issues/17187>`_ """
     if excluded is None:
         excluded = ()
     shape = jnp.broadcast_shapes(*(
@@ -63,6 +70,7 @@ class scipy:

     class special:

+        @functools.wraps(special.gammainccinv)
         def gammainccinv(a, y):
             a = jnp.asarray(a)
             y = jnp.asarray(y)
@@ -73,13 +81,261 @@ class scipy:
     class stats:

         class invgamma:
-
+
             def ppf(q, a):
                 return 1 / scipy.special.gammainccinv(a, q)

 @functools.wraps(jax.vmap)
 def vmap_nodoc(fun, *args, **kw):
+    """
+    Version of `jax.vmap` that preserves the docstring of the input function.
+    """
     doc = fun.__doc__
     fun = jax.vmap(fun, *args, **kw)
     fun.__doc__ = doc
     return fun
+
+def huge_value(x):
+    """
+    Return the maximum value that can be stored in `x`.
+
+    Parameters
+    ----------
+    x : array
+        A numerical numpy or jax array.
+
+    Returns
+    -------
+    maxval : scalar
+        The maximum value allowed by `x`'s type (+inf for floats).
+    """
+    if jnp.issubdtype(x.dtype, jnp.integer):
+        return jnp.iinfo(x.dtype).max
+    else:
+        return jnp.inf
+
+def minimal_unsigned_dtype(max_value):
+    """
+    Return the smallest unsigned integer dtype that can represent a given
+    maximum value (inclusive).
+    """
+    if max_value < 2 ** 8:
+        return jnp.uint8
+    if max_value < 2 ** 16:
+        return jnp.uint16
+    if max_value < 2 ** 32:
+        return jnp.uint32
+    return jnp.uint64
+
+def signed_to_unsigned(int_dtype):
+    """
+    Map a signed integer type to its unsigned counterpart. Unsigned types are
+    passed through.
+    """
+    assert jnp.issubdtype(int_dtype, jnp.integer)
+    if jnp.issubdtype(int_dtype, jnp.unsignedinteger):
+        return int_dtype
+    if int_dtype == jnp.int8:
+        return jnp.uint8
+    if int_dtype == jnp.int16:
+        return jnp.uint16
+    if int_dtype == jnp.int32:
+        return jnp.uint32
+    if int_dtype == jnp.int64:
+        return jnp.uint64
+
+def ensure_unsigned(x):
+    """
+    If x has signed integer type, cast it to the unsigned dtype of the same size.
+    """
+    return x.astype(signed_to_unsigned(x.dtype))
+
+@functools.partial(jax.jit, static_argnums=(1,))
+def unique(x, size, fill_value):
+    """
+    Restricted version of `jax.numpy.unique` that uses less memory.
+
+    Parameters
+    ----------
+    x : 1d array
+        The input array.
+    size : int
+        The length of the output.
+    fill_value : scalar
+        The value to fill the output with if `size` is greater than the number
+        of unique values in `x`.
+
+    Returns
+    -------
+    out : array (size,)
+        The unique values in `x`, sorted, and right-padded with `fill_value`.
+    actual_length : int
+        The number of used values in `out`.
+    """
+    if x.size == 0:
+        return jnp.full(size, fill_value, x.dtype), 0
+    if size == 0:
+        return jnp.empty(0, x.dtype), 0
+    x = jnp.sort(x)
+    def loop(carry, x):
+        i_out, i_in, last, out = carry
+        i_out = jnp.where(x == last, i_out, i_out + 1)
+        out = out.at[i_out].set(x)
+        return (i_out, i_in + 1, x, out), None
+    carry = 0, 0, x[0], jnp.full(size, fill_value, x.dtype)
+    (actual_length, _, _, out), _ = jax.lax.scan(loop, carry, x[:size])
+    return out, actual_length + 1
+
+def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False):
+    """
+    Batch a function such that each batch is smaller than a threshold.
+
+    Parameters
+    ----------
+    func : callable
+        A jittable function with positional arguments only, with inputs and
+        outputs pytrees of arrays.
+    max_io_nbytes : int
+        The maximum number of input + output bytes in each batch.
+    in_axes : pytree of ints, default 0
+        A tree matching the structure of the function input, indicating along
+        which axes each array should be batched. If a single integer, it is
+        used for all arrays.
+    out_axes : pytree of ints, default 0
+        The same for outputs.
+    return_nbatches : bool, default False
+        If True, the number of batches is returned as a second output.
+
+    Returns
+    -------
+    batched_func : callable
+        A function with the same signature as `func`, but that processes the
+        input and output in batches in a loop.
+    """
+
+    def expand_axes(axes, tree):
+        if isinstance(axes, int):
+            return tree_util.tree_map(lambda _: axes, tree)
+        return tree_util.tree_map(lambda _, axis: axis, tree, axes)
+
+    def extract_size(axes, tree):
+        sizes = tree_util.tree_map(lambda x, axis: x.shape[axis], tree, axes)
+        sizes, _ = tree_util.tree_flatten(sizes)
+        assert all(s == sizes[0] for s in sizes)
+        return sizes[0]
+
+    def sum_nbytes(tree):
+        def nbytes(x):
+            return math.prod(x.shape) * x.dtype.itemsize
+        return tree_util.tree_reduce(lambda size, x: size + nbytes(x), tree, 0)
+
+    def next_divisor_small(dividend, min_divisor):
+        for divisor in range(min_divisor, int(math.sqrt(dividend)) + 1):
+            if dividend % divisor == 0:
+                return divisor
+        return dividend
+
+    def next_divisor_large(dividend, min_divisor):
+        max_inv_divisor = dividend // min_divisor
+        for inv_divisor in range(max_inv_divisor, 0, -1):
+            if dividend % inv_divisor == 0:
+                return dividend // inv_divisor
+        return dividend
+
+    def next_divisor(dividend, min_divisor):
+        if min_divisor * min_divisor <= dividend:
+            return next_divisor_small(dividend, min_divisor)
+        return next_divisor_large(dividend, min_divisor)
+
+    def move_axes_out(axes, tree):
+        def move_axis_out(axis, x):
+            if axis != 0:
+                return jnp.moveaxis(x, axis, 0)
+            return x
+        return tree_util.tree_map(move_axis_out, axes, tree)
+
+    def move_axes_in(axes, tree):
+        def move_axis_in(axis, x):
+            if axis != 0:
+                return jnp.moveaxis(x, 0, axis)
+            return x
+        return tree_util.tree_map(move_axis_in, axes, tree)
+
+    def batch(tree, nbatches):
+        def batch(x):
+            return x.reshape((nbatches, x.shape[0] // nbatches) + x.shape[1:])
+        return tree_util.tree_map(batch, tree)
+
+    def unbatch(tree):
+        def unbatch(x):
+            return x.reshape((x.shape[0] * x.shape[1],) + x.shape[2:])
+        return tree_util.tree_map(unbatch, tree)
+
+    def check_same(tree1, tree2):
+        def check_same(x1, x2):
+            assert x1.shape == x2.shape
+            assert x1.dtype == x2.dtype
+        tree_util.tree_map(check_same, tree1, tree2)
+
+    initial_in_axes = in_axes
+    initial_out_axes = out_axes
+
+    @jax.jit
+    @functools.wraps(func)
+    def batched_func(*args):
+        example_result = jax.eval_shape(func, *args)
+
+        in_axes = expand_axes(initial_in_axes, args)
+        out_axes = expand_axes(initial_out_axes, example_result)
+
+        in_size = extract_size(in_axes, args)
+        out_size = extract_size(out_axes, example_result)
+        assert in_size == out_size
+        size = in_size
+
+        total_nbytes = sum_nbytes(args) + sum_nbytes(example_result)
+        min_nbatches = total_nbytes // max_io_nbytes + bool(total_nbytes % max_io_nbytes)
+        nbatches = next_divisor(size, min_nbatches)
+        assert 1 <= nbatches <= size
+        assert size % nbatches == 0
+        assert total_nbytes % nbatches == 0
+
+        batch_nbytes = total_nbytes // nbatches
+        if batch_nbytes > max_io_nbytes:
+            assert size == nbatches
+            warnings.warn(f'batch_nbytes = {batch_nbytes} > max_io_nbytes = {max_io_nbytes}')
+
+        def loop(_, args):
+            args = move_axes_in(in_axes, args)
+            result = func(*args)
+            result = move_axes_out(out_axes, result)
+            return None, result
+
+        args = move_axes_out(in_axes, args)
+        args = batch(args, nbatches)
+        _, result = lax.scan(loop, None, args)
+        result = unbatch(result)
+        result = move_axes_in(out_axes, result)
+
+        check_same(example_result, result)
+
+        if return_nbatches:
+            return result, nbatches
+        return result
+
+    return batched_func
+
+@tree_util.register_pytree_node_class
+class LeafDict(dict):
+    """ dictionary that acts as a leaf in jax pytrees, to store compile-time
+    values """
+
+    def tree_flatten(self):
+        return (), self
+
+    @classmethod
+    def tree_unflatten(cls, aux_data, children):
+        return aux_data
+
+    def __repr__(self):
+        return f'{__class__.__name__}({super().__repr__()})'
bartz/mcmcloop.py
CHANGED
@@ -52,7 +52,7 @@ def run_mcmc(bart, n_burn, n_save, n_skip, callback, key):
     n_save : int
         The number of iterations to save.
     n_skip : int
-        The number of iterations to skip between each saved iteration.
+        The number of iterations to skip between each saved iteration, plus 1.
     callback : callable
         An arbitrary function run at each iteration, called with the following
         arguments, passed by keyword:
@@ -100,15 +100,24 @@ def run_mcmc(bart, n_burn, n_save, n_skip, callback, key):
     def inner_loop(carry, _, tracelist, burnin):
         bart, i_total, i_skip, key = carry
         key, subkey = random.split(key)
-        bart = mcmcstep.
+        bart = mcmcstep.step(bart, subkey)
         callback(bart=bart, burnin=burnin, i_total=i_total, i_skip=i_skip, **callback_kw)
         output = {key: bart[key] for key in tracelist}
         return (bart, i_total + 1, i_skip + 1, key), output

-
-
-
-
+    def empty_trace(bart, tracelist):
+        return {
+            key: jnp.empty((0,) + bart[key].shape, bart[key].dtype)
+            for key in tracelist
+        }
+
+    if n_burn > 0:
+        carry = bart, 0, 0, key
+        burnin_loop = functools.partial(inner_loop, tracelist=tracelist_burnin, burnin=True)
+        (bart, i_total, _, key), burnin_trace = lax.scan(burnin_loop, carry, None, n_burn)
+    else:
+        i_total = 0
+        burnin_trace = empty_trace(bart, tracelist_burnin)

     def outer_loop(carry, _):
         bart, i_total, key = carry
@@ -118,8 +127,11 @@ def run_mcmc(bart, n_burn, n_save, n_skip, callback, key):
         output = {key: bart[key] for key in tracelist_main}
         return (bart, i_total, key), output

-
-
+    if n_save > 0:
+        carry = bart, i_total, key
+        (bart, _, _), main_trace = lax.scan(outer_loop, carry, None, n_save)
+    else:
+        main_trace = empty_trace(bart, tracelist_main)

     return bart, burnin_trace, main_trace

@@ -127,7 +139,8 @@ def run_mcmc(bart, n_burn, n_save, n_skip, callback, key):

 @functools.lru_cache
 # cache to make the callback function object unique, such that the jit
-# of run_mcmc recognizes it
+# of run_mcmc recognizes it => with the callback state, I can make
+# printevery a runtime quantity
 def make_simple_print_callback(printevery):
     """
     Create a logging callback function for MCMC iterations.
@@ -149,11 +162,12 @@ def make_simple_print_callback(printevery):
         grow_acc = bart['grow_acc_count'] / bart['grow_prop_count']
         prune_acc = bart['prune_acc_count'] / bart['prune_prop_count']
         n_total = n_burn + n_save * n_skip
-
+        printcond = (i_total + 1) % printevery == 0
+        debug.callback(_simple_print_callback, burnin, i_total, n_total, grow_prop, grow_acc, prune_prop, prune_acc, printcond)
     return callback

-def
-    if
+def _simple_print_callback(burnin, i_total, n_total, grow_prop, grow_acc, prune_prop, prune_acc, printcond):
+    if printcond:
         burnin_flag = ' (burnin)' if burnin else ''
         total_str = str(n_total)
         ndigits = len(total_str)
@@ -180,6 +194,6 @@ def evaluate_trace(trace, X):
         The predictions for each iteration of the MCMC.
     """
     def loop(_, state):
-        return None, grove.
+        return None, grove.evaluate_forest(X, state['leaf_trees'], state['var_trees'], state['split_trees'], jnp.float32)
     _, y = lax.scan(loop, None, trace)
     return y