bartz 0.5.0__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bartz/jaxext.py CHANGED
@@ -22,25 +22,25 @@
22
22
  # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23
23
  # SOFTWARE.
24
24
 
25
+ """Additions to jax."""
26
+
25
27
  import functools
26
28
  import math
27
29
  import warnings
28
30
 
29
31
  import jax
30
- from jax import lax, tree_util
32
+ from jax import lax, random, tree_util
31
33
  from jax import numpy as jnp
32
34
  from scipy import special
33
35
 
34
36
 
35
37
  def float_type(*args):
36
- """
37
- Determine the jax floating point result type given operands/types.
38
- """
38
+ """Determine the jax floating point result type given operands/types."""
39
39
  t = jnp.result_type(*args)
40
40
  return jnp.sin(jnp.empty(0, t)).dtype
41
41
 
42
42
 
43
- def castto(func, type):
43
+ def _castto(func, type):
44
44
  @functools.wraps(func)
45
45
  def newfunc(*args, **kw):
46
46
  return func(*args, **kw).astype(type)
@@ -49,26 +49,37 @@ def castto(func, type):
49
49
 
50
50
 
51
51
  class scipy:
52
+ """Mockup of the :external:py:mod:`scipy` module."""
53
+
52
54
  class special:
53
- @functools.wraps(special.gammainccinv)
55
+ """Mockup of the :external:py:mod:`scipy.special` module."""
56
+
57
+ @staticmethod
54
58
  def gammainccinv(a, y):
59
+ """Survival function inverse of the Gamma(a, 1) distribution."""
55
60
  a = jnp.asarray(a)
56
61
  y = jnp.asarray(y)
57
62
  shape = jnp.broadcast_shapes(a.shape, y.shape)
58
63
  dtype = float_type(a.dtype, y.dtype)
59
64
  dummy = jax.ShapeDtypeStruct(shape, dtype)
60
- ufunc = castto(special.gammainccinv, dtype)
65
+ ufunc = _castto(special.gammainccinv, dtype)
61
66
  return jax.pure_callback(ufunc, dummy, a, y, vmap_method='expand_dims')
62
67
 
63
68
  class stats:
69
+ """Mockup of the :external:py:mod:`scipy.stats` module."""
70
+
64
71
  class invgamma:
72
+ """Class that represents the distribution InvGamma(a, 1)."""
73
+
74
+ @staticmethod
65
75
  def ppf(q, a):
76
+ """Percentile point function."""
66
77
  return 1 / scipy.special.gammainccinv(a, q)
67
78
 
68
79
 
69
80
  def vmap_nodoc(fun, *args, **kw):
70
81
  """
71
- Wrapper of `jax.vmap` that preserves the docstring of the input function.
82
+ Acts like `jax.vmap` but preserves the docstring of the function unchanged.
72
83
 
73
84
  This is useful if the docstring already takes into account that the
74
85
  arguments have additional axes due to vmap.
@@ -99,24 +110,22 @@ def huge_value(x):
99
110
  return jnp.inf
100
111
 
101
112
 
102
def minimal_unsigned_dtype(value):
    """
    Return the smallest unsigned integer dtype that can represent `value`.

    The candidates are tried in order of increasing width. No range check is
    performed: any value below ``2**8`` (negative ones included) yields
    `jnp.uint8`, and any value of ``2**32`` or more yields `jnp.uint64`.
    """
    candidates = (
        (8, jnp.uint8),
        (16, jnp.uint16),
        (32, jnp.uint32),
    )
    for nbits, dtype in candidates:
        if value < 2**nbits:
            return dtype
    return jnp.uint64
114
122
 
115
123
 
116
124
  def signed_to_unsigned(int_dtype):
117
125
  """
118
- Map a signed integer type to its unsigned counterpart. Unsigned types are
119
- passed through.
126
+ Map a signed integer type to its unsigned counterpart.
127
+
128
+ Unsigned types are passed through.
120
129
  """
121
130
  assert jnp.issubdtype(int_dtype, jnp.integer)
122
131
  if jnp.issubdtype(int_dtype, jnp.unsignedinteger):
@@ -132,9 +141,7 @@ def signed_to_unsigned(int_dtype):
132
141
 
133
142
 
134
143
  def ensure_unsigned(x):
135
- """
136
- If x has signed integer type, cast it to the unsigned dtype of the same size.
137
- """
144
+ """If x has signed integer type, cast it to the unsigned dtype of the same size."""
138
145
  return x.astype(signed_to_unsigned(x.dtype))
139
146
 
140
147
 
@@ -358,17 +365,59 @@ def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
358
365
  return batched_func
359
366
 
360
367
 
361
- @tree_util.register_pytree_node_class
362
- class LeafDict(dict):
363
- """dictionary that acts as a leaf in jax pytrees, to store compile-time
364
- values"""
365
-
366
- def tree_flatten(self):
367
- return (), self
368
class split:
    """
    Split a key into `num` keys.

    The new keys are stored as a flat list and handed out in order by `pop`,
    so each key is returned at most once.

    Parameters
    ----------
    key : jax.dtypes.prng_key array
        The key to split. Legacy raw ``uint32`` keys are accepted as well.
    num : int
        The number of keys to split into.
    """

    def __init__(self, key, num=2):
        keys = random.split(key, num)
        # New-style typed keys are scalar array elements, while legacy raw
        # keys are uint32 arrays whose last axis holds the key data. That
        # trailing axis must not be counted or sliced as if it were keys.
        if jnp.issubdtype(keys.dtype, jax.dtypes.prng_key):
            self._data_shape = ()
        else:
            self._data_shape = keys.shape[-1:]
        # Store the keys as a flat list along axis 0.
        self._keys = keys.reshape((-1, *self._data_shape))

    def __len__(self):
        """Return the number of keys that have not been popped yet."""
        return self._keys.shape[0]

    def pop(self, shape=None):
        """
        Pop one or more keys from the list.

        Parameters
        ----------
        shape : int or tuple of int, optional
            The shape of the keys to pop. If `None`, a single key is popped.
            If an integer, that many keys are popped. If a tuple, the keys are
            reshaped to that shape.

        Returns
        -------
        keys : jax.dtypes.prng_key array
            The popped keys. For legacy raw keys the key data axis is kept
            as an additional trailing axis.

        Raises
        ------
        IndexError
            If `shape` is larger than the number of keys left in the list.

        Notes
        -----
        The keys are popped from the beginning of the list, so for example
        ``list(keys.pop(2))`` is equivalent to ``[keys.pop(), keys.pop()]``.
        """
        if shape is None:
            shape = ()
        elif not isinstance(shape, tuple):
            shape = (shape,)
        # math.prod(()) == 1, so the default pops exactly one key
        size_to_pop = math.prod(shape)
        available = len(self)
        if size_to_pop > available:
            raise IndexError(
                f'Cannot pop {size_to_pop} keys from {available} keys'
            )
        popped_keys = self._keys[:size_to_pop]
        self._keys = self._keys[size_to_pop:]
        return popped_keys.reshape(shape + self._data_shape)