bartz 0.5.0__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bartz/BART.py +196 -103
- bartz/__init__.py +1 -1
- bartz/_version.py +1 -1
- bartz/debug.py +1 -1
- bartz/grove.py +43 -2
- bartz/jaxext.py +82 -33
- bartz/mcmcloop.py +367 -114
- bartz/mcmcstep.py +1322 -807
- bartz/prepcovars.py +3 -1
- {bartz-0.5.0.dist-info → bartz-0.6.0.dist-info}/METADATA +7 -5
- bartz-0.6.0.dist-info/RECORD +13 -0
- {bartz-0.5.0.dist-info → bartz-0.6.0.dist-info}/WHEEL +1 -1
- bartz-0.5.0.dist-info/RECORD +0 -13
bartz/jaxext.py
CHANGED
|
@@ -22,25 +22,25 @@
|
|
|
22
22
|
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
23
23
|
# SOFTWARE.
|
|
24
24
|
|
|
25
|
+
"""Additions to jax."""
|
|
26
|
+
|
|
25
27
|
import functools
|
|
26
28
|
import math
|
|
27
29
|
import warnings
|
|
28
30
|
|
|
29
31
|
import jax
|
|
30
|
-
from jax import lax, tree_util
|
|
32
|
+
from jax import lax, random, tree_util
|
|
31
33
|
from jax import numpy as jnp
|
|
32
34
|
from scipy import special
|
|
33
35
|
|
|
34
36
|
|
|
35
37
|
def float_type(*args):
|
|
36
|
-
"""
|
|
37
|
-
Determine the jax floating point result type given operands/types.
|
|
38
|
-
"""
|
|
38
|
+
"""Determine the jax floating point result type given operands/types."""
|
|
39
39
|
t = jnp.result_type(*args)
|
|
40
40
|
return jnp.sin(jnp.empty(0, t)).dtype
|
|
41
41
|
|
|
42
42
|
|
|
43
|
-
def castto(func, type):
|
|
43
|
+
def _castto(func, type):
|
|
44
44
|
@functools.wraps(func)
|
|
45
45
|
def newfunc(*args, **kw):
|
|
46
46
|
return func(*args, **kw).astype(type)
|
|
@@ -49,26 +49,37 @@ def castto(func, type):
|
|
|
49
49
|
|
|
50
50
|
|
|
51
51
|
class scipy:
|
|
52
|
+
"""Mockup of the :external:py:mod:`scipy` module."""
|
|
53
|
+
|
|
52
54
|
class special:
|
|
53
|
-
|
|
55
|
+
"""Mockup of the :external:py:mod:`scipy.special` module."""
|
|
56
|
+
|
|
57
|
+
@staticmethod
|
|
54
58
|
def gammainccinv(a, y):
|
|
59
|
+
"""Survival function inverse of the Gamma(a, 1) distribution."""
|
|
55
60
|
a = jnp.asarray(a)
|
|
56
61
|
y = jnp.asarray(y)
|
|
57
62
|
shape = jnp.broadcast_shapes(a.shape, y.shape)
|
|
58
63
|
dtype = float_type(a.dtype, y.dtype)
|
|
59
64
|
dummy = jax.ShapeDtypeStruct(shape, dtype)
|
|
60
|
-
ufunc = castto(special.gammainccinv, dtype)
|
|
65
|
+
ufunc = _castto(special.gammainccinv, dtype)
|
|
61
66
|
return jax.pure_callback(ufunc, dummy, a, y, vmap_method='expand_dims')
|
|
62
67
|
|
|
63
68
|
class stats:
|
|
69
|
+
"""Mockup of the :external:py:mod:`scipy.stats` module."""
|
|
70
|
+
|
|
64
71
|
class invgamma:
|
|
72
|
+
"""Class that represents the distribution InvGamma(a, 1)."""
|
|
73
|
+
|
|
74
|
+
@staticmethod
|
|
65
75
|
def ppf(q, a):
|
|
76
|
+
"""Percentile point function."""
|
|
66
77
|
return 1 / scipy.special.gammainccinv(a, q)
|
|
67
78
|
|
|
68
79
|
|
|
69
80
|
def vmap_nodoc(fun, *args, **kw):
|
|
70
81
|
"""
|
|
71
|
-
|
|
82
|
+
Acts like `jax.vmap` but preserves the docstring of the function unchanged.
|
|
72
83
|
|
|
73
84
|
This is useful if the docstring already takes into account that the
|
|
74
85
|
arguments have additional axes due to vmap.
|
|
@@ -99,24 +110,22 @@ def huge_value(x):
|
|
|
99
110
|
return jnp.inf
|
|
100
111
|
|
|
101
112
|
|
|
102
|
-
def minimal_unsigned_dtype(max_value):
|
|
103
|
-
"""
|
|
104
|
-
|
|
105
|
-
maximum value (inclusive).
|
|
106
|
-
"""
|
|
107
|
-
if max_value < 2**8:
|
|
113
|
+
def minimal_unsigned_dtype(value):
|
|
114
|
+
"""Return the smallest unsigned integer dtype that can represent `value`."""
|
|
115
|
+
if value < 2**8:
|
|
108
116
|
return jnp.uint8
|
|
109
|
-
if max_value < 2**16:
|
|
117
|
+
if value < 2**16:
|
|
110
118
|
return jnp.uint16
|
|
111
|
-
if max_value < 2**32:
|
|
119
|
+
if value < 2**32:
|
|
112
120
|
return jnp.uint32
|
|
113
121
|
return jnp.uint64
|
|
114
122
|
|
|
115
123
|
|
|
116
124
|
def signed_to_unsigned(int_dtype):
|
|
117
125
|
"""
|
|
118
|
-
Map a signed integer type to its unsigned counterpart.
|
|
119
|
-
|
|
126
|
+
Map a signed integer type to its unsigned counterpart.
|
|
127
|
+
|
|
128
|
+
Unsigned types are passed through.
|
|
120
129
|
"""
|
|
121
130
|
assert jnp.issubdtype(int_dtype, jnp.integer)
|
|
122
131
|
if jnp.issubdtype(int_dtype, jnp.unsignedinteger):
|
|
@@ -132,9 +141,7 @@ def signed_to_unsigned(int_dtype):
|
|
|
132
141
|
|
|
133
142
|
|
|
134
143
|
def ensure_unsigned(x):
|
|
135
|
-
"""
|
|
136
|
-
If x has signed integer type, cast it to the unsigned dtype of the same size.
|
|
137
|
-
"""
|
|
144
|
+
"""If x has signed integer type, cast it to the unsigned dtype of the same size."""
|
|
138
145
|
return x.astype(signed_to_unsigned(x.dtype))
|
|
139
146
|
|
|
140
147
|
|
|
@@ -358,17 +365,59 @@ def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
|
|
|
358
365
|
return batched_func
|
|
359
366
|
|
|
360
367
|
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
values"""
|
|
365
|
-
|
|
366
|
-
def tree_flatten(self):
|
|
367
|
-
return (), self
|
|
368
|
+
class split:
|
|
369
|
+
"""
|
|
370
|
+
Split a key into `num` keys.
|
|
368
371
|
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
+
Parameters
|
|
373
|
+
----------
|
|
374
|
+
key : jax.dtypes.prng_key array
|
|
375
|
+
The key to split.
|
|
376
|
+
num : int
|
|
377
|
+
The number of keys to split into.
|
|
378
|
+
"""
|
|
372
379
|
|
|
373
|
-
def
|
|
374
|
-
|
|
380
|
+
def __init__(self, key, num=2):
|
|
381
|
+
self._keys = random.split(key, num)
|
|
382
|
+
|
|
383
|
+
def __len__(self):
|
|
384
|
+
return self._keys.size
|
|
385
|
+
|
|
386
|
+
def pop(self, shape=None):
|
|
387
|
+
"""
|
|
388
|
+
Pop one or more keys from the list.
|
|
389
|
+
|
|
390
|
+
Parameters
|
|
391
|
+
----------
|
|
392
|
+
shape : int or tuple of int, optional
|
|
393
|
+
The shape of the keys to pop. If `None`, a single key is popped.
|
|
394
|
+
If an integer, that many keys are popped. If a tuple, the keys are
|
|
395
|
+
reshaped to that shape.
|
|
396
|
+
|
|
397
|
+
Returns
|
|
398
|
+
-------
|
|
399
|
+
keys : jax.dtypes.prng_key array
|
|
400
|
+
The popped keys.
|
|
401
|
+
|
|
402
|
+
Raises
|
|
403
|
+
------
|
|
404
|
+
IndexError
|
|
405
|
+
If `shape` is larger than the number of keys left in the list.
|
|
406
|
+
|
|
407
|
+
Notes
|
|
408
|
+
-----
|
|
409
|
+
The keys are popped from the beginning of the list, so for example
|
|
410
|
+
``list(keys.pop(2))`` is equivalent to ``[keys.pop(), keys.pop()]``.
|
|
411
|
+
"""
|
|
412
|
+
if shape is None:
|
|
413
|
+
shape = ()
|
|
414
|
+
elif not isinstance(shape, tuple):
|
|
415
|
+
shape = (shape,)
|
|
416
|
+
size_to_pop = math.prod(shape)
|
|
417
|
+
if size_to_pop > self._keys.size:
|
|
418
|
+
raise IndexError(
|
|
419
|
+
f'Cannot pop {size_to_pop} keys from {self._keys.size} keys'
|
|
420
|
+
)
|
|
421
|
+
popped_keys = self._keys[:size_to_pop]
|
|
422
|
+
self._keys = self._keys[size_to_pop:]
|
|
423
|
+
return popped_keys.reshape(shape)
|