brainstate 0.0.2.post20241009__py2.py3-none-any.whl → 0.1.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. brainstate/__init__.py +31 -11
  2. brainstate/_state.py +760 -316
  3. brainstate/_state_test.py +41 -12
  4. brainstate/_utils.py +31 -4
  5. brainstate/augment/__init__.py +40 -0
  6. brainstate/augment/_autograd.py +608 -0
  7. brainstate/augment/_autograd_test.py +1193 -0
  8. brainstate/augment/_eval_shape.py +102 -0
  9. brainstate/augment/_eval_shape_test.py +40 -0
  10. brainstate/augment/_mapping.py +525 -0
  11. brainstate/augment/_mapping_test.py +210 -0
  12. brainstate/augment/_random.py +99 -0
  13. brainstate/{transform → compile}/__init__.py +25 -13
  14. brainstate/compile/_ad_checkpoint.py +204 -0
  15. brainstate/compile/_ad_checkpoint_test.py +51 -0
  16. brainstate/compile/_conditions.py +259 -0
  17. brainstate/compile/_conditions_test.py +221 -0
  18. brainstate/compile/_error_if.py +94 -0
  19. brainstate/compile/_error_if_test.py +54 -0
  20. brainstate/compile/_jit.py +314 -0
  21. brainstate/compile/_jit_test.py +143 -0
  22. brainstate/compile/_loop_collect_return.py +516 -0
  23. brainstate/compile/_loop_collect_return_test.py +59 -0
  24. brainstate/compile/_loop_no_collection.py +185 -0
  25. brainstate/compile/_loop_no_collection_test.py +51 -0
  26. brainstate/compile/_make_jaxpr.py +756 -0
  27. brainstate/compile/_make_jaxpr_test.py +134 -0
  28. brainstate/compile/_progress_bar.py +111 -0
  29. brainstate/compile/_unvmap.py +159 -0
  30. brainstate/compile/_util.py +147 -0
  31. brainstate/environ.py +408 -381
  32. brainstate/environ_test.py +34 -32
  33. brainstate/{nn/event → event}/__init__.py +6 -6
  34. brainstate/event/_csr.py +308 -0
  35. brainstate/event/_csr_test.py +118 -0
  36. brainstate/event/_fixed_probability.py +271 -0
  37. brainstate/event/_fixed_probability_test.py +128 -0
  38. brainstate/event/_linear.py +219 -0
  39. brainstate/event/_linear_test.py +112 -0
  40. brainstate/{nn/event → event}/_misc.py +7 -7
  41. brainstate/functional/_activations.py +521 -511
  42. brainstate/functional/_activations_test.py +300 -300
  43. brainstate/functional/_normalization.py +43 -43
  44. brainstate/functional/_others.py +15 -15
  45. brainstate/functional/_spikes.py +49 -49
  46. brainstate/graph/__init__.py +33 -0
  47. brainstate/graph/_graph_context.py +443 -0
  48. brainstate/graph/_graph_context_test.py +65 -0
  49. brainstate/graph/_graph_convert.py +246 -0
  50. brainstate/graph/_graph_node.py +300 -0
  51. brainstate/graph/_graph_node_test.py +75 -0
  52. brainstate/graph/_graph_operation.py +1746 -0
  53. brainstate/graph/_graph_operation_test.py +724 -0
  54. brainstate/init/_base.py +28 -10
  55. brainstate/init/_generic.py +175 -172
  56. brainstate/init/_random_inits.py +470 -415
  57. brainstate/init/_random_inits_test.py +150 -0
  58. brainstate/init/_regular_inits.py +66 -69
  59. brainstate/init/_regular_inits_test.py +51 -0
  60. brainstate/mixin.py +236 -244
  61. brainstate/mixin_test.py +44 -46
  62. brainstate/nn/__init__.py +26 -51
  63. brainstate/nn/_collective_ops.py +199 -0
  64. brainstate/nn/_dyn_impl/__init__.py +46 -0
  65. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  66. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  67. brainstate/nn/_dyn_impl/_dynamics_synapse.py +320 -0
  68. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  69. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  70. brainstate/nn/{_projection/__init__.py → _dyn_impl/_projection_alignpost.py} +6 -13
  71. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  72. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  73. brainstate/nn/_dyn_impl/_readout.py +128 -0
  74. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  75. brainstate/nn/_dynamics/__init__.py +37 -0
  76. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  77. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  78. brainstate/nn/_dynamics/_projection_base.py +346 -0
  79. brainstate/nn/_dynamics/_state_delay.py +453 -0
  80. brainstate/nn/_dynamics/_synouts.py +161 -0
  81. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  82. brainstate/nn/_elementwise/__init__.py +22 -0
  83. brainstate/nn/_elementwise/_dropout.py +418 -0
  84. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  85. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  86. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  87. brainstate/nn/_exp_euler.py +97 -0
  88. brainstate/nn/_exp_euler_test.py +36 -0
  89. brainstate/nn/_interaction/__init__.py +32 -0
  90. brainstate/nn/_interaction/_connections.py +726 -0
  91. brainstate/nn/_interaction/_connections_test.py +254 -0
  92. brainstate/nn/_interaction/_embedding.py +59 -0
  93. brainstate/nn/_interaction/_normalizations.py +388 -0
  94. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  95. brainstate/nn/_interaction/_poolings.py +1179 -0
  96. brainstate/nn/_interaction/_poolings_test.py +219 -0
  97. brainstate/nn/_module.py +328 -0
  98. brainstate/nn/_module_test.py +211 -0
  99. brainstate/nn/metrics.py +309 -309
  100. brainstate/optim/__init__.py +14 -2
  101. brainstate/optim/_base.py +66 -0
  102. brainstate/optim/_lr_scheduler.py +363 -400
  103. brainstate/optim/_lr_scheduler_test.py +25 -24
  104. brainstate/optim/_optax_optimizer.py +103 -176
  105. brainstate/optim/_optax_optimizer_test.py +41 -1
  106. brainstate/optim/_sgd_optimizer.py +950 -1025
  107. brainstate/random/_rand_funs.py +3269 -3268
  108. brainstate/random/_rand_funs_test.py +568 -0
  109. brainstate/random/_rand_seed.py +149 -117
  110. brainstate/random/_rand_seed_test.py +50 -0
  111. brainstate/random/_rand_state.py +1360 -1318
  112. brainstate/random/_random_for_unit.py +13 -13
  113. brainstate/surrogate.py +1262 -1243
  114. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  115. brainstate/typing.py +157 -130
  116. brainstate/util/__init__.py +52 -0
  117. brainstate/util/_caller.py +100 -0
  118. brainstate/util/_dict.py +734 -0
  119. brainstate/util/_dict_test.py +160 -0
  120. brainstate/util/_error.py +28 -0
  121. brainstate/util/_filter.py +178 -0
  122. brainstate/util/_others.py +497 -0
  123. brainstate/util/_pretty_repr.py +208 -0
  124. brainstate/util/_scaling.py +260 -0
  125. brainstate/util/_struct.py +524 -0
  126. brainstate/util/_tracers.py +75 -0
  127. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  128. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/METADATA +11 -11
  129. brainstate-0.1.0.dist-info/RECORD +135 -0
  130. brainstate/_module.py +0 -1637
  131. brainstate/_module_test.py +0 -207
  132. brainstate/nn/_base.py +0 -251
  133. brainstate/nn/_connections.py +0 -686
  134. brainstate/nn/_dynamics.py +0 -426
  135. brainstate/nn/_elementwise.py +0 -1438
  136. brainstate/nn/_embedding.py +0 -66
  137. brainstate/nn/_misc.py +0 -133
  138. brainstate/nn/_normalizations.py +0 -389
  139. brainstate/nn/_others.py +0 -101
  140. brainstate/nn/_poolings.py +0 -1229
  141. brainstate/nn/_poolings_test.py +0 -231
  142. brainstate/nn/_projection/_align_post.py +0 -546
  143. brainstate/nn/_projection/_align_pre.py +0 -599
  144. brainstate/nn/_projection/_delta.py +0 -241
  145. brainstate/nn/_projection/_vanilla.py +0 -101
  146. brainstate/nn/_rate_rnns.py +0 -410
  147. brainstate/nn/_readout.py +0 -136
  148. brainstate/nn/_synouts.py +0 -166
  149. brainstate/nn/event/csr.py +0 -312
  150. brainstate/nn/event/csr_test.py +0 -118
  151. brainstate/nn/event/fixed_probability.py +0 -276
  152. brainstate/nn/event/fixed_probability_test.py +0 -127
  153. brainstate/nn/event/linear.py +0 -220
  154. brainstate/nn/event/linear_test.py +0 -111
  155. brainstate/random/random_test.py +0 -593
  156. brainstate/transform/_autograd.py +0 -585
  157. brainstate/transform/_autograd_test.py +0 -1181
  158. brainstate/transform/_conditions.py +0 -334
  159. brainstate/transform/_conditions_test.py +0 -220
  160. brainstate/transform/_error_if.py +0 -94
  161. brainstate/transform/_error_if_test.py +0 -55
  162. brainstate/transform/_jit.py +0 -265
  163. brainstate/transform/_jit_test.py +0 -118
  164. brainstate/transform/_loop_collect_return.py +0 -502
  165. brainstate/transform/_loop_no_collection.py +0 -170
  166. brainstate/transform/_make_jaxpr.py +0 -739
  167. brainstate/transform/_make_jaxpr_test.py +0 -131
  168. brainstate/transform/_mapping.py +0 -109
  169. brainstate/transform/_progress_bar.py +0 -111
  170. brainstate/transform/_unvmap.py +0 -143
  171. brainstate/util.py +0 -746
  172. brainstate-0.0.2.post20241009.dist-info/RECORD +0 -87
  173. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/LICENSE +0 -0
  174. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/WHEEL +0 -0
  175. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/top_level.txt +0 -0
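The dominant structural change in this release is a reorganization of the package: the old brainstate.transform namespace is split into brainstate.compile (JIT, conditionals, loops, jaxpr construction) and brainstate.augment (autograd, mapping, eval-shape), and brainstate.nn.event is promoted to a top-level brainstate.event package. A minimal migration sketch follows; the public names are inferred from the new file layout above (compile/_jit.py, augment/_autograd.py), so treat them as assumptions rather than a verified 0.1.0 API:

    # Migration sketch for the transform -> compile/augment split.
    # Assumed names: brainstate.compile.jit and brainstate.augment.grad,
    # inferred from the 0.1.0 file layout listed above.
    import brainstate as bst
    import jax.numpy as jnp

    w = bst.State(jnp.zeros(3))        # mutable state tracked by brainstate

    @bst.compile.jit                   # 0.0.2 spelling: bst.transform.jit
    def step(x):
        w.value = w.value + x
        return w.value

    # gradients w.r.t. tracked states moved from transform to augment
    loss = lambda x: (step(x) ** 2).sum()
    grad_fn = bst.augment.grad(loss, w)   # 0.0.2: bst.transform.grad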
@@ -14,7 +14,10 @@
14
14
  # ==============================================================================
15
15
 
16
16
  # -*- coding: utf-8 -*-
17
+ from __future__ import annotations
17
18
 
19
+ from collections import namedtuple
20
+ from functools import partial
18
21
  from operator import index
19
22
  from typing import Optional
20
23
 
@@ -23,11 +26,12 @@ import jax
23
26
  import jax.numpy as jnp
24
27
  import jax.random as jr
25
28
  import numpy as np
26
- from jax import lax, core
29
+ from jax import jit, vmap
30
+ from jax import lax, core, dtypes
27
31
 
28
32
  from brainstate import environ
29
33
  from brainstate._state import State
30
- from brainstate.transform._error_if import jit_error_if
34
+ from brainstate.compile._error_if import jit_error_if
31
35
  from brainstate.typing import DTypeLike, Size, SeedOrKey
32
36
  from ._random_for_unit import uniform_for_unit, permutation_for_unit
33
37
 
@@ -35,1060 +39,1098 @@ __all__ = ['RandomState', 'DEFAULT', ]
35
39
 
36
40
 
37
41
  class RandomState(State):
38
- """RandomState that track the random generator state. """
39
- __slots__ = ()
40
-
41
- def __init__(self, seed_or_key: Optional[SeedOrKey] = None):
42
- """RandomState constructor.
43
-
44
- Parameters
45
- ----------
46
- seed_or_key: int, Array, optional
47
- It can be an integer for initial seed of the random number generator,
48
- or it can be a JAX's PRNKey, which is an array with two elements and `uint32` dtype.
49
- """
50
- with jax.ensure_compile_time_eval():
51
- if seed_or_key is None:
52
- seed_or_key = np.random.randint(0, 100000, 2, dtype=np.uint32)
53
- if isinstance(seed_or_key, int):
54
- key = jr.PRNGKey(seed_or_key)
55
- else:
56
- if len(seed_or_key) != 2 and seed_or_key.dtype != np.uint32:
57
- raise ValueError('key must be an array with dtype uint32. '
58
- f'But we got {seed_or_key}')
59
- key = seed_or_key
60
- super().__init__(key)
61
-
62
- def __repr__(self) -> str:
63
- print_code = repr(self.value)
64
- i = print_code.index('(')
65
- return f'{self.__class__.__name__}(key={print_code[i:]})'
66
-
67
- def _check_if_deleted(self):
68
- if isinstance(self._value, jax.Array) and not isinstance(self._value, jax.core.Tracer) and self._value.is_deleted():
69
- self.seed()
70
-
71
- # ------------------- #
72
- # seed and random key #
73
- # ------------------- #
74
-
75
- def clone(self):
76
- return type(self)(self.split_key())
77
-
78
- def seed(self, seed_or_key: Optional[SeedOrKey] = None):
79
- """Sets a new random seed.
80
-
81
- Parameters
82
- ----------
83
- seed_or_key: int, ArrayLike, optional
84
- It can be an integer for initial seed of the random number generator,
85
- or it can be a JAX's PRNKey, which is an array with two elements and `uint32` dtype.
86
- """
87
- with jax.ensure_compile_time_eval():
88
- if seed_or_key is None:
89
- seed_or_key = np.random.randint(0, 100000, 2, dtype=np.uint32)
90
- if np.size(seed_or_key) == 1:
91
- key = jr.PRNGKey(seed_or_key)
92
- else:
93
- if len(seed_or_key) != 2 and seed_or_key.dtype != np.uint32:
94
- raise ValueError('key must be an array with dtype uint32. '
95
- f'But we got {seed_or_key}')
96
- key = seed_or_key
97
- self.value = key
98
-
99
- def split_key(self):
100
- """Create a new seed from the current seed.
101
- """
102
- if not isinstance(self.value, jax.Array):
103
- self.value = jnp.asarray(self.value, dtype=jnp.uint32)
104
- keys = jr.split(self.value, num=2)
105
- self.value = keys[0]
106
- return keys[1]
107
-
108
- def split_keys(self, n: int):
109
- """Create multiple seeds from the current seed. This is used
110
- internally by `pmap` and `vmap` to ensure that random numbers
111
- are different in parallel threads.
112
-
113
- Parameters
114
- ----------
115
- n : int
116
- The number of seeds to generate.
117
- """
118
- keys = jr.split(self.value, n + 1)
119
- self.value = keys[0]
120
- return keys[1:]
121
-
122
- # ---------------- #
123
- # random functions #
124
- # ---------------- #
42
+ """RandomState that track the random generator state. """
43
+
44
+ # __slots__ = ('_backup', '_value')
45
+
46
+ def __init__(self, seed_or_key: Optional[SeedOrKey] = None):
47
+ """RandomState constructor.
48
+
49
+ Parameters
50
+ ----------
51
+ seed_or_key: int, Array, optional
52
+ It can be an integer for initial seed of the random number generator,
53
+ or it can be a JAX's PRNKey, which is an array with two elements and `uint32` dtype.
54
+ """
55
+ with jax.ensure_compile_time_eval():
56
+ if seed_or_key is None:
57
+ seed_or_key = np.random.randint(0, 100000, 2, dtype=np.uint32)
58
+ if isinstance(seed_or_key, int):
59
+ key = jr.PRNGKey(seed_or_key)
60
+ else:
61
+ if len(seed_or_key) != 2 and seed_or_key.dtype != np.uint32:
62
+ raise ValueError('key must be an array with dtype uint32. '
63
+ f'But we got {seed_or_key}')
64
+ key = seed_or_key
65
+ super().__init__(key)
66
+
67
+ self._backup = None
68
+
69
+ def __repr__(self):
70
+ return f'{self.__class__.__name__}({self.value})'
71
+
72
+ def check_if_deleted(self):
73
+ if (
74
+ isinstance(self._value, jax.Array) and
75
+ not isinstance(self._value, jax.core.Tracer) and
76
+ self._value.is_deleted()
77
+ ):
78
+ self.seed()
79
+
80
+ # ------------------- #
81
+ # seed and random key #
82
+ # ------------------- #
83
+
84
+ def backup_key(self):
85
+ if self._backup is not None:
86
+ raise ValueError('The random key has been backed up, and has not been restored.')
87
+ self._backup = self.value
88
+
89
+ def restore_key(self):
90
+ if self._backup is None:
91
+ raise ValueError('The random key has not been backed up.')
92
+ self.value = self._backup
93
+ self._backup = None
94
+
95
+ def clone(self):
96
+ return type(self)(self.split_key())
97
+
98
+ def set_key(self, key: SeedOrKey):
99
+ self.value = key
100
+
101
+ def seed(self, seed_or_key: Optional[SeedOrKey] = None):
102
+ """Sets a new random seed.
103
+
104
+ Parameters
105
+ ----------
106
+ seed_or_key: int, ArrayLike, optional
107
+ It can be an integer for initial seed of the random number generator,
108
+ or it can be a JAX's PRNKey, which is an array with two elements and `uint32` dtype.
109
+ """
110
+ with jax.ensure_compile_time_eval():
111
+ if seed_or_key is None:
112
+ seed_or_key = np.random.randint(0, 100000, 2, dtype=np.uint32)
113
+ if np.size(seed_or_key) == 1:
114
+ key = jr.PRNGKey(seed_or_key)
115
+ else:
116
+ if len(seed_or_key) != 2 and seed_or_key.dtype != np.uint32:
117
+ raise ValueError('key must be an array with dtype uint32. '
118
+ f'But we got {seed_or_key}')
119
+ key = seed_or_key
120
+ self.value = key
121
+
122
+ def split_key(self, n: Optional[int] = None, backup: bool = False) -> SeedOrKey:
123
+ """
124
+ Create a new seed from the current seed.
125
+
126
+ Parameters
127
+ ----------
128
+ n: int, optional
129
+ The number of seeds to generate.
130
+ backup : bool, optional
131
+ Whether to backup the current key.
132
+
133
+ Returns
134
+ -------
135
+ key : SeedOrKey
136
+ The new seed or a tuple of JAX random keys.
137
+ """
138
+ if n is not None:
139
+ assert isinstance(n, int) and n >= 1, f'n should be an integer greater than 1, but we got {n}'
140
+
141
+ if not isinstance(self.value, jax.Array):
142
+ self.value = jnp.asarray(self.value, dtype=jnp.uint32)
143
+ keys = jr.split(self.value, num=2 if n is None else n + 1)
144
+ self.value = keys[0]
145
+ if backup:
146
+ self.backup_key()
147
+ if n is None:
148
+ return keys[1]
149
+ else:
150
+ return keys[1:]
151
+
152
+ def self_assign_multi_keys(self, n: int, backup: bool = True):
153
+ """
154
+ Self-assign multiple keys to the current random state.
155
+ """
156
+ if backup:
157
+ keys = jr.split(self.value, n + 1)
158
+ self.value = keys[0]
159
+ self.backup_key()
160
+ self.value = keys[1:]
161
+ else:
162
+ self.value = jr.split(self.value, n)
163
+
164
+ # ---------------- #
165
+ # random functions #
166
+ # ---------------- #
167
+
168
+ def rand(self, *dn, key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
169
+ key = self.split_key() if key is None else _formalize_key(key)
170
+ dtype = dtype or environ.dftype()
171
+ r = uniform_for_unit(key, shape=dn, minval=0., maxval=1., dtype=dtype)
172
+ return r
173
+
174
+ def randint(
175
+ self,
176
+ low,
177
+ high=None,
178
+ size: Optional[Size] = None,
179
+ dtype: DTypeLike = None,
180
+ key: Optional[SeedOrKey] = None
181
+ ):
182
+ if high is None:
183
+ high = low
184
+ low = 0
185
+ high = _check_py_seq(high)
186
+ low = _check_py_seq(low)
187
+ if size is None:
188
+ size = lax.broadcast_shapes(jnp.shape(low),
189
+ jnp.shape(high))
190
+ key = self.split_key() if key is None else _formalize_key(key)
191
+ dtype = dtype or environ.ditype()
192
+ r = jr.randint(key,
193
+ shape=_size2shape(size),
194
+ minval=low, maxval=high, dtype=dtype)
195
+ return r
196
+
197
+ def random_integers(
198
+ self,
199
+ low,
200
+ high=None,
201
+ size: Optional[Size] = None,
202
+ key: Optional[SeedOrKey] = None,
203
+ dtype: DTypeLike = None,
204
+ ):
205
+ low = _check_py_seq(low)
206
+ high = _check_py_seq(high)
207
+ if high is None:
208
+ high = low
209
+ low = 1
210
+ high += 1
211
+ if size is None:
212
+ size = lax.broadcast_shapes(jnp.shape(low), jnp.shape(high))
213
+ key = self.split_key() if key is None else _formalize_key(key)
214
+ dtype = dtype or environ.ditype()
215
+ r = jr.randint(key,
216
+ shape=_size2shape(size),
217
+ minval=low,
218
+ maxval=high,
219
+ dtype=dtype)
220
+ return r
221
+
222
+ def randn(self, *dn, key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
223
+ key = self.split_key() if key is None else _formalize_key(key)
224
+ dtype = dtype or environ.dftype()
225
+ r = jr.normal(key, shape=dn, dtype=dtype)
226
+ return r
227
+
228
+ def random(self,
229
+ size: Optional[Size] = None,
230
+ key: Optional[SeedOrKey] = None,
231
+ dtype: DTypeLike = None):
232
+ dtype = dtype or environ.dftype()
233
+ key = self.split_key() if key is None else _formalize_key(key)
234
+ r = uniform_for_unit(key, shape=_size2shape(size), minval=0., maxval=1., dtype=dtype)
235
+ return r
125
236
 
126
- def rand(self, *dn, key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
127
- key = self.split_key() if key is None else _formalize_key(key)
128
- dtype = dtype or environ.dftype()
129
- r = uniform_for_unit(key, shape=dn, minval=0., maxval=1., dtype=dtype)
130
- return r
131
-
132
- def randint(
133
- self,
134
- low,
135
- high=None,
136
- size: Optional[Size] = None,
137
- dtype: DTypeLike = None,
138
- key: Optional[SeedOrKey] = None
139
- ):
140
- if high is None:
141
- high = low
142
- low = 0
143
- high = _check_py_seq(high)
144
- low = _check_py_seq(low)
145
- if size is None:
146
- size = lax.broadcast_shapes(jnp.shape(low),
147
- jnp.shape(high))
148
- key = self.split_key() if key is None else _formalize_key(key)
149
- dtype = dtype or environ.ditype()
150
- r = jr.randint(key,
151
- shape=_size2shape(size),
152
- minval=low, maxval=high, dtype=dtype)
153
- return r
154
-
155
- def random_integers(
156
- self,
157
- low,
158
- high=None,
159
- size: Optional[Size] = None,
160
- key: Optional[SeedOrKey] = None,
161
- dtype: DTypeLike = None,
162
- ):
163
- low = _check_py_seq(low)
164
- high = _check_py_seq(high)
165
- if high is None:
166
- high = low
167
- low = 1
168
- high += 1
169
- if size is None:
170
- size = lax.broadcast_shapes(jnp.shape(low), jnp.shape(high))
171
- key = self.split_key() if key is None else _formalize_key(key)
172
- dtype = dtype or environ.ditype()
173
- r = jr.randint(key,
174
- shape=_size2shape(size),
175
- minval=low,
176
- maxval=high,
177
- dtype=dtype)
178
- return r
179
-
180
- def randn(self, *dn, key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
181
- key = self.split_key() if key is None else _formalize_key(key)
182
- dtype = dtype or environ.dftype()
183
- r = jr.normal(key, shape=dn, dtype=dtype)
184
- return r
237
+ def random_sample(self,
238
+ size: Optional[Size] = None,
239
+ key: Optional[SeedOrKey] = None,
240
+ dtype: DTypeLike = None):
241
+ r = self.random(size=size, key=key, dtype=dtype)
242
+ return r
185
243
 
186
- def random(self,
244
+ def ranf(self,
187
245
  size: Optional[Size] = None,
188
246
  key: Optional[SeedOrKey] = None,
189
247
  dtype: DTypeLike = None):
190
- dtype = dtype or environ.dftype()
191
- key = self.split_key() if key is None else _formalize_key(key)
192
- r = uniform_for_unit(key, shape=_size2shape(size), minval=0., maxval=1., dtype=dtype)
193
- return r
248
+ r = self.random(size=size, key=key, dtype=dtype)
249
+ return r
194
250
 
195
- def random_sample(self,
196
- size: Optional[Size] = None,
197
- key: Optional[SeedOrKey] = None,
198
- dtype: DTypeLike = None):
199
- r = self.random(size=size, key=key, dtype=dtype)
200
- return r
201
-
202
- def ranf(self,
203
- size: Optional[Size] = None,
204
- key: Optional[SeedOrKey] = None,
205
- dtype: DTypeLike = None):
206
- r = self.random(size=size, key=key, dtype=dtype)
207
- return r
251
+ def sample(self,
252
+ size: Optional[Size] = None,
253
+ key: Optional[SeedOrKey] = None,
254
+ dtype: DTypeLike = None):
255
+ r = self.random(size=size, key=key, dtype=dtype)
256
+ return r
208
257
 
209
- def sample(self,
210
- size: Optional[Size] = None,
211
- key: Optional[SeedOrKey] = None,
212
- dtype: DTypeLike = None):
213
- r = self.random(size=size, key=key, dtype=dtype)
214
- return r
258
+ def choice(self,
259
+ a,
260
+ size: Optional[Size] = None,
261
+ replace=True,
262
+ p=None,
263
+ key: Optional[SeedOrKey] = None):
264
+ a = _check_py_seq(a)
265
+ p = _check_py_seq(p)
266
+ key = self.split_key() if key is None else _formalize_key(key)
267
+ r = jr.choice(key, a=a, shape=_size2shape(size), replace=replace, p=p)
268
+ return r
269
+
270
+ def permutation(self,
271
+ x,
272
+ axis: int = 0,
273
+ independent: bool = False,
274
+ key: Optional[SeedOrKey] = None):
275
+ x = _check_py_seq(x)
276
+ key = self.split_key() if key is None else _formalize_key(key)
277
+ r = permutation_for_unit(key, x, axis=axis, independent=independent)
278
+ return r
279
+
280
+ def shuffle(self,
281
+ x,
282
+ axis=0,
283
+ key: Optional[SeedOrKey] = None):
284
+ key = self.split_key() if key is None else _formalize_key(key)
285
+ x = permutation_for_unit(key, x, axis=axis)
286
+ return x
215
287
 
216
- def choice(self,
288
+ def beta(self,
217
289
  a,
218
- size: Optional[Size] = None,
219
- replace=True,
220
- p=None,
221
- key: Optional[SeedOrKey] = None):
222
- a = _check_py_seq(a)
223
- p = _check_py_seq(p)
224
- key = self.split_key() if key is None else _formalize_key(key)
225
- r = jr.choice(key, a=a, shape=_size2shape(size), replace=replace, p=p)
226
- return r
227
-
228
- def permutation(self,
229
- x,
230
- axis: int = 0,
231
- independent: bool = False,
232
- key: Optional[SeedOrKey] = None):
233
- x = _check_py_seq(x)
234
- key = self.split_key() if key is None else _formalize_key(key)
235
- r = permutation_for_unit(key, x, axis=axis, independent=independent)
236
- return r
237
-
238
- def shuffle(self,
239
- x,
240
- axis=0,
241
- key: Optional[SeedOrKey] = None):
242
- key = self.split_key() if key is None else _formalize_key(key)
243
- x = permutation_for_unit(key, x, axis=axis)
244
- return x
245
-
246
- def beta(self,
247
- a,
248
- b,
249
- size: Optional[Size] = None,
250
- key: Optional[SeedOrKey] = None,
251
- dtype: DTypeLike = None):
252
- a = _check_py_seq(a)
253
- b = _check_py_seq(b)
254
- if size is None:
255
- size = lax.broadcast_shapes(jnp.shape(a), jnp.shape(b))
256
- key = self.split_key() if key is None else _formalize_key(key)
257
- dtype = dtype or environ.dftype()
258
- r = jr.beta(key, a=a, b=b, shape=_size2shape(size), dtype=dtype)
259
- return r
260
-
261
- def exponential(self,
262
- scale=None,
263
- size: Optional[Size] = None,
264
- key: Optional[SeedOrKey] = None,
265
- dtype: DTypeLike = None):
266
- if size is None:
267
- size = jnp.shape(scale)
268
- key = self.split_key() if key is None else _formalize_key(key)
269
- dtype = dtype or environ.dftype()
270
- scale = jnp.asarray(scale, dtype=dtype)
271
- r = jr.exponential(key, shape=_size2shape(size), dtype=dtype)
272
- if scale is not None:
273
- r = r / scale
274
- return r
275
-
276
- def gamma(self,
277
- shape,
278
- scale=None,
279
- size: Optional[Size] = None,
280
- key: Optional[SeedOrKey] = None,
281
- dtype: DTypeLike = None):
282
- shape = _check_py_seq(shape)
283
- scale = _check_py_seq(scale)
284
- if size is None:
285
- size = lax.broadcast_shapes(jnp.shape(shape), jnp.shape(scale))
286
- key = self.split_key() if key is None else _formalize_key(key)
287
- dtype = dtype or environ.dftype()
288
- r = jr.gamma(key, a=shape, shape=_size2shape(size), dtype=dtype)
289
- if scale is not None:
290
- r = r * scale
291
- return r
292
-
293
- def gumbel(self,
294
- loc=None,
295
- scale=None,
290
+ b,
296
291
  size: Optional[Size] = None,
297
292
  key: Optional[SeedOrKey] = None,
298
293
  dtype: DTypeLike = None):
299
- loc = _check_py_seq(loc)
300
- scale = _check_py_seq(scale)
301
- if size is None:
302
- size = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
303
- key = self.split_key() if key is None else _formalize_key(key)
304
- dtype = dtype or environ.dftype()
305
- r = _loc_scale(loc, scale, jr.gumbel(key, shape=_size2shape(size), dtype=dtype))
306
- return r
307
-
308
- def laplace(self,
309
- loc=None,
294
+ a = _check_py_seq(a)
295
+ b = _check_py_seq(b)
296
+ if size is None:
297
+ size = lax.broadcast_shapes(jnp.shape(a), jnp.shape(b))
298
+ key = self.split_key() if key is None else _formalize_key(key)
299
+ dtype = dtype or environ.dftype()
300
+ r = jr.beta(key, a=a, b=b, shape=_size2shape(size), dtype=dtype)
301
+ return r
302
+
303
+ def exponential(self,
304
+ scale=None,
305
+ size: Optional[Size] = None,
306
+ key: Optional[SeedOrKey] = None,
307
+ dtype: DTypeLike = None):
308
+ if size is None:
309
+ size = jnp.shape(scale)
310
+ key = self.split_key() if key is None else _formalize_key(key)
311
+ dtype = dtype or environ.dftype()
312
+ scale = jnp.asarray(scale, dtype=dtype)
313
+ r = jr.exponential(key, shape=_size2shape(size), dtype=dtype)
314
+ if scale is not None:
315
+ r = r / scale
316
+ return r
317
+
318
+ def gamma(self,
319
+ shape,
310
320
  scale=None,
311
321
  size: Optional[Size] = None,
312
322
  key: Optional[SeedOrKey] = None,
313
323
  dtype: DTypeLike = None):
314
- loc = _check_py_seq(loc)
315
- scale = _check_py_seq(scale)
316
- if size is None:
317
- size = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
318
- key = self.split_key() if key is None else _formalize_key(key)
319
- dtype = dtype or environ.dftype()
320
- r = _loc_scale(loc, scale, jr.laplace(key, shape=_size2shape(size), dtype=dtype))
321
- return r
322
-
323
- def logistic(self,
324
+ shape = _check_py_seq(shape)
325
+ scale = _check_py_seq(scale)
326
+ if size is None:
327
+ size = lax.broadcast_shapes(jnp.shape(shape), jnp.shape(scale))
328
+ key = self.split_key() if key is None else _formalize_key(key)
329
+ dtype = dtype or environ.dftype()
330
+ r = jr.gamma(key, a=shape, shape=_size2shape(size), dtype=dtype)
331
+ if scale is not None:
332
+ r = r * scale
333
+ return r
334
+
335
+ def gumbel(self,
324
336
  loc=None,
325
337
  scale=None,
326
338
  size: Optional[Size] = None,
327
339
  key: Optional[SeedOrKey] = None,
328
340
  dtype: DTypeLike = None):
329
- loc = _check_py_seq(loc)
330
- scale = _check_py_seq(scale)
331
- if size is None:
332
- size = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
333
- key = self.split_key() if key is None else _formalize_key(key)
334
- dtype = dtype or environ.dftype()
335
- r = _loc_scale(loc, scale, jr.logistic(key, shape=_size2shape(size), dtype=dtype))
336
- return r
337
-
338
- def normal(self,
339
- loc=None,
340
- scale=None,
341
- size: Optional[Size] = None,
342
- key: Optional[SeedOrKey] = None,
343
- dtype: DTypeLike = None):
344
- loc = _check_py_seq(loc)
345
- scale = _check_py_seq(scale)
346
- if size is None:
347
- size = lax.broadcast_shapes(jnp.shape(scale), jnp.shape(loc))
348
- key = self.split_key() if key is None else _formalize_key(key)
349
- dtype = dtype or environ.dftype()
350
- r = _loc_scale(loc, scale, jr.normal(key, shape=_size2shape(size), dtype=dtype))
351
- return r
352
-
353
- def pareto(self,
354
- a,
355
- size: Optional[Size] = None,
356
- key: Optional[SeedOrKey] = None,
357
- dtype: DTypeLike = None):
358
- if size is None:
359
- size = jnp.shape(a)
360
- key = self.split_key() if key is None else _formalize_key(key)
361
- dtype = dtype or environ.dftype()
362
- a = jnp.asarray(a, dtype=dtype)
363
- r = jr.pareto(key, b=a, shape=_size2shape(size), dtype=dtype)
364
- return r
365
-
366
- def poisson(self,
367
- lam=1.0,
368
- size: Optional[Size] = None,
369
- key: Optional[SeedOrKey] = None,
370
- dtype: DTypeLike = None):
371
- lam = _check_py_seq(lam)
372
- if size is None:
373
- size = jnp.shape(lam)
374
- key = self.split_key() if key is None else _formalize_key(key)
375
- dtype = dtype or environ.ditype()
376
- r = jr.poisson(key, lam=lam, shape=_size2shape(size), dtype=dtype)
377
- return r
378
-
379
- def standard_cauchy(self,
380
- size: Optional[Size] = None,
381
- key: Optional[SeedOrKey] = None,
382
- dtype: DTypeLike = None):
383
- key = self.split_key() if key is None else _formalize_key(key)
384
- dtype = dtype or environ.dftype()
385
- r = jr.cauchy(key, shape=_size2shape(size), dtype=dtype)
386
- return r
387
-
388
- def standard_exponential(self,
389
- size: Optional[Size] = None,
390
- key: Optional[SeedOrKey] = None,
391
- dtype: DTypeLike = None):
392
- key = self.split_key() if key is None else _formalize_key(key)
393
- dtype = dtype or environ.dftype()
394
- r = jr.exponential(key, shape=_size2shape(size), dtype=dtype)
395
- return r
396
-
397
- def standard_gamma(self,
398
- shape,
399
- size: Optional[Size] = None,
400
- key: Optional[SeedOrKey] = None,
401
- dtype: DTypeLike = None):
402
- shape = _check_py_seq(shape)
403
- if size is None:
404
- size = jnp.shape(shape)
405
- key = self.split_key() if key is None else _formalize_key(key)
406
- dtype = dtype or environ.dftype()
407
- r = jr.gamma(key, a=shape, shape=_size2shape(size), dtype=dtype)
408
- return r
409
-
410
- def standard_normal(self,
411
- size: Optional[Size] = None,
412
- key: Optional[SeedOrKey] = None,
413
- dtype: DTypeLike = None):
414
- key = self.split_key() if key is None else _formalize_key(key)
415
- dtype = dtype or environ.dftype()
416
- r = jr.normal(key, shape=_size2shape(size), dtype=dtype)
417
- return r
418
-
419
- def standard_t(self, df,
341
+ loc = _check_py_seq(loc)
342
+ scale = _check_py_seq(scale)
343
+ if size is None:
344
+ size = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
345
+ key = self.split_key() if key is None else _formalize_key(key)
346
+ dtype = dtype or environ.dftype()
347
+ r = _loc_scale(loc, scale, jr.gumbel(key, shape=_size2shape(size), dtype=dtype))
348
+ return r
349
+
350
+ def laplace(self,
351
+ loc=None,
352
+ scale=None,
353
+ size: Optional[Size] = None,
354
+ key: Optional[SeedOrKey] = None,
355
+ dtype: DTypeLike = None):
356
+ loc = _check_py_seq(loc)
357
+ scale = _check_py_seq(scale)
358
+ if size is None:
359
+ size = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
360
+ key = self.split_key() if key is None else _formalize_key(key)
361
+ dtype = dtype or environ.dftype()
362
+ r = _loc_scale(loc, scale, jr.laplace(key, shape=_size2shape(size), dtype=dtype))
363
+ return r
364
+
365
+ def logistic(self,
366
+ loc=None,
367
+ scale=None,
420
368
  size: Optional[Size] = None,
421
369
  key: Optional[SeedOrKey] = None,
422
370
  dtype: DTypeLike = None):
423
- df = _check_py_seq(df)
424
- if size is None:
425
- size = jnp.shape(size)
426
- key = self.split_key() if key is None else _formalize_key(key)
427
- dtype = dtype or environ.dftype()
428
- r = jr.t(key, df=df, shape=_size2shape(size), dtype=dtype)
429
- return r
430
-
431
- def uniform(self,
432
- low=0.0,
433
- high=1.0,
434
- size: Optional[Size] = None,
435
- key: Optional[SeedOrKey] = None,
436
- dtype: DTypeLike = None):
437
- low = _check_py_seq(low)
438
- high = _check_py_seq(high)
439
- if size is None:
440
- size = lax.broadcast_shapes(jnp.shape(low), jnp.shape(high))
441
- key = self.split_key() if key is None else _formalize_key(key)
442
- dtype = dtype or environ.dftype()
443
- r = uniform_for_unit(key, shape=_size2shape(size), minval=low, maxval=high, dtype=dtype)
444
- return r
445
-
446
- def __norm_cdf(self, x, sqrt2, dtype):
447
- # Computes standard normal cumulative distribution function
448
- return (np.asarray(1., dtype) + lax.erf(x / sqrt2)) / np.asarray(2., dtype)
449
-
450
- def truncated_normal(
451
- self,
452
- lower,
453
- upper,
454
- size: Optional[Size] = None,
455
- loc=0.,
456
- scale=1.,
457
- key: Optional[SeedOrKey] = None,
458
- dtype: DTypeLike = None
459
- ):
460
- lower = _check_py_seq(lower)
461
- upper = _check_py_seq(upper)
462
- loc = _check_py_seq(loc)
463
- scale = _check_py_seq(scale)
464
- dtype = dtype or environ.dftype()
465
-
466
- lower = u.math.asarray(lower, dtype=dtype)
467
- upper = u.math.asarray(upper, dtype=dtype)
468
- loc = u.math.asarray(loc, dtype=dtype)
469
- scale = u.math.asarray(scale, dtype=dtype)
470
- unit = u.get_unit(lower)
471
- lower, upper, loc, scale = (
472
- lower.mantissa if isinstance(lower, u.Quantity) else lower,
473
- u.Quantity(upper).in_unit(unit).mantissa,
474
- u.Quantity(loc).in_unit(unit).mantissa,
475
- u.Quantity(scale).in_unit(unit).mantissa
476
- )
477
-
478
- jit_error_if(
479
- u.math.any(u.math.logical_or(loc < lower - 2 * scale, loc > upper + 2 * scale)),
480
- "mean is more than 2 std from [lower, upper] in truncated_normal. "
481
- "The distribution of values may be incorrect."
482
- )
483
-
484
- if size is None:
485
- size = u.math.broadcast_shapes(jnp.shape(lower),
486
- jnp.shape(upper),
487
- jnp.shape(loc),
488
- jnp.shape(scale))
489
-
490
- # Values are generated by using a truncated uniform distribution and
491
- # then using the inverse CDF for the normal distribution.
492
- # Get upper and lower cdf values
493
- sqrt2 = np.array(np.sqrt(2), dtype=dtype)
494
- l = self.__norm_cdf((lower - loc) / scale, sqrt2, dtype)
495
- u = self.__norm_cdf((upper - loc) / scale, sqrt2, dtype)
496
-
497
- # Uniformly fill tensor with values from [l, u], then translate to
498
- # [2l-1, 2u-1].
499
- key = self.split_key() if key is None else _formalize_key(key)
500
- out = uniform_for_unit(
501
- key, size, dtype,
502
- minval=lax.nextafter(2 * l - 1, np.array(np.inf, dtype=dtype)),
503
- maxval=lax.nextafter(2 * u - 1, np.array(-np.inf, dtype=dtype))
504
- )
505
-
506
- # Use inverse cdf transform for normal distribution to get truncated
507
- # standard normal
508
- out = lax.erf_inv(out)
509
-
510
- # Transform to proper mean, std
511
- out = out * scale * sqrt2 + loc
512
-
513
- # Clamp to ensure it's in the proper range
514
- out = jnp.clip(
515
- out,
516
- lax.nextafter(lax.stop_gradient(lower), np.array(np.inf, dtype=dtype)),
517
- lax.nextafter(lax.stop_gradient(upper), np.array(-np.inf, dtype=dtype))
518
- )
519
- return out if unit.is_unitless else u.Quantity(out, unit=unit)
520
-
521
- def _check_p(self, p):
522
- raise ValueError(f'Parameter p should be within [0, 1], but we got {p}')
523
-
524
- def bernoulli(self,
525
- p,
526
- size: Optional[Size] = None,
527
- key: Optional[SeedOrKey] = None):
528
- p = _check_py_seq(p)
529
- jit_error_if(jnp.any(jnp.logical_and(p < 0, p > 1)), self._check_p, p)
530
- if size is None:
531
- size = jnp.shape(p)
532
- key = self.split_key() if key is None else _formalize_key(key)
533
- r = jr.bernoulli(key, p=p, shape=_size2shape(size))
534
- return r
535
-
536
- def lognormal(
537
- self,
538
- mean=None,
539
- sigma=None,
540
- size: Optional[Size] = None,
541
- key: Optional[SeedOrKey] = None,
542
- dtype: DTypeLike = None
543
- ):
544
- mean = _check_py_seq(mean)
545
- sigma = _check_py_seq(sigma)
546
- mean = u.math.asarray(mean, dtype=dtype)
547
- sigma = u.math.asarray(sigma, dtype=dtype)
548
- unit = mean.unit if isinstance(mean, u.Quantity) else u.Unit()
549
- mean = mean.mantissa if isinstance(mean, u.Quantity) else mean
550
- sigma = sigma.in_unit(unit).mantissa if isinstance(sigma, u.Quantity) else sigma
551
-
552
- if size is None:
553
- size = jnp.broadcast_shapes(
554
- jnp.shape(mean),
555
- jnp.shape(sigma)
556
- )
557
- key = self.split_key() if key is None else _formalize_key(key)
558
- dtype = dtype or environ.dftype()
559
- samples = jr.normal(key, shape=_size2shape(size), dtype=dtype)
560
- samples = _loc_scale(mean, sigma, samples)
561
- samples = jnp.exp(samples)
562
- return samples if unit.is_unitless else u.Quantity(samples, unit=unit)
563
-
564
- def binomial(self,
565
- n,
566
- p,
371
+ loc = _check_py_seq(loc)
372
+ scale = _check_py_seq(scale)
373
+ if size is None:
374
+ size = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
375
+ key = self.split_key() if key is None else _formalize_key(key)
376
+ dtype = dtype or environ.dftype()
377
+ r = _loc_scale(loc, scale, jr.logistic(key, shape=_size2shape(size), dtype=dtype))
378
+ return r
379
+
380
+ def normal(self,
381
+ loc=None,
382
+ scale=None,
567
383
  size: Optional[Size] = None,
568
384
  key: Optional[SeedOrKey] = None,
569
385
  dtype: DTypeLike = None):
570
- n = _check_py_seq(n)
571
- p = _check_py_seq(p)
572
- jit_error_if(jnp.any(jnp.logical_and(p < 0, p > 1)), self._check_p, p)
573
- if size is None:
574
- size = jnp.broadcast_shapes(jnp.shape(n), jnp.shape(p))
575
- key = self.split_key() if key is None else _formalize_key(key)
576
- r = _binomial(key, p, n, shape=_size2shape(size))
577
- dtype = dtype or environ.ditype()
578
- return jnp.asarray(r, dtype=dtype)
579
-
580
- def chisquare(self,
581
- df,
582
- size: Optional[Size] = None,
583
- key: Optional[SeedOrKey] = None,
584
- dtype: DTypeLike = None):
585
- df = _check_py_seq(df)
586
- key = self.split_key() if key is None else _formalize_key(key)
587
- dtype = dtype or environ.dftype()
588
- if size is None:
589
- if jnp.ndim(df) == 0:
590
- dist = jr.normal(key, (df,), dtype=dtype) ** 2
591
- dist = dist.sum()
592
- else:
593
- raise NotImplementedError('Do not support non-scale "df" when "size" is None')
594
- else:
595
- dist = jr.normal(key, (df,) + _size2shape(size), dtype=dtype) ** 2
596
- dist = dist.sum(axis=0)
597
- return dist
598
-
599
- def dirichlet(self,
600
- alpha,
386
+ loc = _check_py_seq(loc)
387
+ scale = _check_py_seq(scale)
388
+ if size is None:
389
+ size = lax.broadcast_shapes(jnp.shape(scale), jnp.shape(loc))
390
+ key = self.split_key() if key is None else _formalize_key(key)
391
+ dtype = dtype or environ.dftype()
392
+ r = _loc_scale(loc, scale, jr.normal(key, shape=_size2shape(size), dtype=dtype))
393
+ return r
394
+
395
+ def pareto(self,
396
+ a,
397
+ size: Optional[Size] = None,
398
+ key: Optional[SeedOrKey] = None,
399
+ dtype: DTypeLike = None):
400
+ if size is None:
401
+ size = jnp.shape(a)
402
+ key = self.split_key() if key is None else _formalize_key(key)
403
+ dtype = dtype or environ.dftype()
404
+ a = jnp.asarray(a, dtype=dtype)
405
+ r = jr.pareto(key, b=a, shape=_size2shape(size), dtype=dtype)
406
+ return r
407
+
408
+ def poisson(self,
409
+ lam=1.0,
601
410
  size: Optional[Size] = None,
602
411
  key: Optional[SeedOrKey] = None,
603
412
  dtype: DTypeLike = None):
604
- key = self.split_key() if key is None else _formalize_key(key)
605
- alpha = _check_py_seq(alpha)
606
- dtype = dtype or environ.dftype()
607
- r = jr.dirichlet(key, alpha=alpha, shape=_size2shape(size), dtype=dtype)
608
- return r
413
+ lam = _check_py_seq(lam)
414
+ if size is None:
415
+ size = jnp.shape(lam)
416
+ key = self.split_key() if key is None else _formalize_key(key)
417
+ dtype = dtype or environ.ditype()
418
+ r = jr.poisson(key, lam=lam, shape=_size2shape(size), dtype=dtype)
419
+ return r
420
+
421
+ def standard_cauchy(self,
422
+ size: Optional[Size] = None,
423
+ key: Optional[SeedOrKey] = None,
424
+ dtype: DTypeLike = None):
425
+ key = self.split_key() if key is None else _formalize_key(key)
426
+ dtype = dtype or environ.dftype()
427
+ r = jr.cauchy(key, shape=_size2shape(size), dtype=dtype)
428
+ return r
429
+
430
+ def standard_exponential(self,
431
+ size: Optional[Size] = None,
432
+ key: Optional[SeedOrKey] = None,
433
+ dtype: DTypeLike = None):
434
+ key = self.split_key() if key is None else _formalize_key(key)
435
+ dtype = dtype or environ.dftype()
436
+ r = jr.exponential(key, shape=_size2shape(size), dtype=dtype)
437
+ return r
438
+
439
+ def standard_gamma(self,
440
+ shape,
441
+ size: Optional[Size] = None,
442
+ key: Optional[SeedOrKey] = None,
443
+ dtype: DTypeLike = None):
444
+ shape = _check_py_seq(shape)
445
+ if size is None:
446
+ size = jnp.shape(shape)
447
+ key = self.split_key() if key is None else _formalize_key(key)
448
+ dtype = dtype or environ.dftype()
449
+ r = jr.gamma(key, a=shape, shape=_size2shape(size), dtype=dtype)
450
+ return r
451
+
452
+ def standard_normal(self,
453
+ size: Optional[Size] = None,
454
+ key: Optional[SeedOrKey] = None,
455
+ dtype: DTypeLike = None):
456
+ key = self.split_key() if key is None else _formalize_key(key)
457
+ dtype = dtype or environ.dftype()
458
+ r = jr.normal(key, shape=_size2shape(size), dtype=dtype)
459
+ return r
609
460
 
610
- def geometric(self,
611
- p,
461
+ def standard_t(self, df,
462
+ size: Optional[Size] = None,
463
+ key: Optional[SeedOrKey] = None,
464
+ dtype: DTypeLike = None):
465
+ df = _check_py_seq(df)
466
+ if size is None:
467
+ size = jnp.shape(size)
468
+ key = self.split_key() if key is None else _formalize_key(key)
469
+ dtype = dtype or environ.dftype()
470
+ r = jr.t(key, df=df, shape=_size2shape(size), dtype=dtype)
471
+ return r
472
+
473
+ def uniform(self,
474
+ low=0.0,
475
+ high=1.0,
612
476
  size: Optional[Size] = None,
613
477
  key: Optional[SeedOrKey] = None,
614
478
  dtype: DTypeLike = None):
615
- p = _check_py_seq(p)
616
- if size is None:
617
- size = jnp.shape(p)
618
- key = self.split_key() if key is None else _formalize_key(key)
619
- dtype = dtype or environ.dftype()
620
- u = uniform_for_unit(key, size, dtype=dtype)
621
- r = jnp.floor(jnp.log1p(-u) / jnp.log1p(-p))
622
- return r
623
-
624
- def _check_p2(self, p):
625
- raise ValueError(f'We require `sum(pvals[:-1]) <= 1`. But we got {p}')
626
-
627
- def multinomial(self,
628
- n,
629
- pvals,
479
+ low = _check_py_seq(low)
480
+ high = _check_py_seq(high)
481
+ if size is None:
482
+ size = lax.broadcast_shapes(jnp.shape(low), jnp.shape(high))
483
+ key = self.split_key() if key is None else _formalize_key(key)
484
+ dtype = dtype or environ.dftype()
485
+ r = uniform_for_unit(key, shape=_size2shape(size), minval=low, maxval=high, dtype=dtype)
486
+ return r
487
+
488
+ def __norm_cdf(self, x, sqrt2, dtype):
489
+ # Computes standard normal cumulative distribution function
490
+ return (np.asarray(1., dtype) + lax.erf(x / sqrt2)) / np.asarray(2., dtype)
491
+
492
+ def truncated_normal(
493
+ self,
494
+ lower,
495
+ upper,
496
+ size: Optional[Size] = None,
497
+ loc=0.,
498
+ scale=1.,
499
+ key: Optional[SeedOrKey] = None,
500
+ dtype: DTypeLike = None
501
+ ):
502
+ lower = _check_py_seq(lower)
503
+ upper = _check_py_seq(upper)
504
+ loc = _check_py_seq(loc)
505
+ scale = _check_py_seq(scale)
506
+ dtype = dtype or environ.dftype()
507
+
508
+ lower = u.math.asarray(lower, dtype=dtype)
509
+ upper = u.math.asarray(upper, dtype=dtype)
510
+ loc = u.math.asarray(loc, dtype=dtype)
511
+ scale = u.math.asarray(scale, dtype=dtype)
512
+ unit = u.get_unit(lower)
513
+ lower, upper, loc, scale = (
514
+ lower.mantissa if isinstance(lower, u.Quantity) else lower,
515
+ u.Quantity(upper).in_unit(unit).mantissa,
516
+ u.Quantity(loc).in_unit(unit).mantissa,
517
+ u.Quantity(scale).in_unit(unit).mantissa
518
+ )
519
+
520
+ jit_error_if(
521
+ u.math.any(u.math.logical_or(loc < lower - 2 * scale, loc > upper + 2 * scale)),
522
+ "mean is more than 2 std from [lower, upper] in truncated_normal. "
523
+ "The distribution of values may be incorrect."
524
+ )
525
+
526
+ if size is None:
527
+ size = u.math.broadcast_shapes(jnp.shape(lower),
528
+ jnp.shape(upper),
529
+ jnp.shape(loc),
530
+ jnp.shape(scale))
531
+
532
+ # Values are generated by using a truncated uniform distribution and
533
+ # then using the inverse CDF for the normal distribution.
534
+ # Get upper and lower cdf values
535
+ sqrt2 = np.array(np.sqrt(2), dtype=dtype)
536
+ l = self.__norm_cdf((lower - loc) / scale, sqrt2, dtype)
537
+ u_ = self.__norm_cdf((upper - loc) / scale, sqrt2, dtype)
538
+
539
+ # Uniformly fill tensor with values from [l, u], then translate to
540
+ # [2l-1, 2u-1].
541
+ key = self.split_key() if key is None else _formalize_key(key)
542
+ out = uniform_for_unit(
543
+ key, size, dtype,
544
+ minval=lax.nextafter(2 * l - 1, np.array(np.inf, dtype=dtype)),
545
+ maxval=lax.nextafter(2 * u_ - 1, np.array(-np.inf, dtype=dtype))
546
+ )
547
+
548
+ # Use inverse cdf transform for normal distribution to get truncated
549
+ # standard normal
550
+ out = lax.erf_inv(out)
551
+
552
+ # Transform to proper mean, std
553
+ out = out * scale * sqrt2 + loc
554
+
555
+ # Clamp to ensure it's in the proper range
556
+ out = jnp.clip(
557
+ out,
558
+ lax.nextafter(lax.stop_gradient(lower), np.array(np.inf, dtype=dtype)),
559
+ lax.nextafter(lax.stop_gradient(upper), np.array(-np.inf, dtype=dtype))
560
+ )
561
+ return out if unit.is_unitless else u.Quantity(out, unit=unit)
562
+
563
+ def _check_p(self, p):
564
+ raise ValueError(f'Parameter p should be within [0, 1], but we got {p}')
565
+
566
+ def bernoulli(self,
567
+ p,
568
+ size: Optional[Size] = None,
569
+ key: Optional[SeedOrKey] = None):
570
+ p = _check_py_seq(p)
571
+ jit_error_if(jnp.any(jnp.logical_and(p < 0, p > 1)), self._check_p, p)
572
+ if size is None:
573
+ size = jnp.shape(p)
574
+ key = self.split_key() if key is None else _formalize_key(key)
575
+ r = jr.bernoulli(key, p=p, shape=_size2shape(size))
576
+ return r
577
+
578
+ def lognormal(
579
+ self,
580
+ mean=None,
581
+ sigma=None,
582
+ size: Optional[Size] = None,
583
+ key: Optional[SeedOrKey] = None,
584
+ dtype: DTypeLike = None
585
+ ):
586
+ mean = _check_py_seq(mean)
587
+ sigma = _check_py_seq(sigma)
588
+ mean = u.math.asarray(mean, dtype=dtype)
589
+ sigma = u.math.asarray(sigma, dtype=dtype)
590
+ unit = mean.unit if isinstance(mean, u.Quantity) else u.Unit()
591
+ mean = mean.mantissa if isinstance(mean, u.Quantity) else mean
592
+ sigma = sigma.in_unit(unit).mantissa if isinstance(sigma, u.Quantity) else sigma
593
+
594
+ if size is None:
595
+ size = jnp.broadcast_shapes(
596
+ jnp.shape(mean),
597
+ jnp.shape(sigma)
598
+ )
599
+ key = self.split_key() if key is None else _formalize_key(key)
600
+ dtype = dtype or environ.dftype()
601
+ samples = jr.normal(key, shape=_size2shape(size), dtype=dtype)
602
+ samples = _loc_scale(mean, sigma, samples)
603
+ samples = jnp.exp(samples)
604
+ return samples if unit.is_unitless else u.Quantity(samples, unit=unit)
605
+
606
+ def binomial(self,
607
+ n,
608
+ p,
609
+ size: Optional[Size] = None,
610
+ key: Optional[SeedOrKey] = None,
611
+ dtype: DTypeLike = None):
612
+ n = _check_py_seq(n)
613
+ p = _check_py_seq(p)
614
+ jit_error_if(jnp.any(jnp.logical_and(p < 0, p > 1)), self._check_p, p)
615
+ if size is None:
616
+ size = jnp.broadcast_shapes(jnp.shape(n), jnp.shape(p))
617
+ key = self.split_key() if key is None else _formalize_key(key)
618
+ r = _binomial(key, p, n, shape=_size2shape(size))
619
+ dtype = dtype or environ.ditype()
620
+ return jnp.asarray(r, dtype=dtype)
621
+
622
+ def chisquare(self,
623
+ df,
630
624
  size: Optional[Size] = None,
631
625
  key: Optional[SeedOrKey] = None,
632
626
  dtype: DTypeLike = None):
633
- key = self.split_key() if key is None else _formalize_key(key)
634
- n = _check_py_seq(n)
635
- pvals = _check_py_seq(pvals)
636
- jit_error_if(jnp.sum(pvals[:-1]) > 1., self._check_p2, pvals)
637
- if isinstance(n, jax.core.Tracer):
638
- raise ValueError("The total count parameter `n` should not be a jax abstract array.")
639
- size = _size2shape(size)
640
- n_max = int(np.max(jax.device_get(n)))
641
- batch_shape = lax.broadcast_shapes(jnp.shape(pvals)[:-1], jnp.shape(n))
642
- r = _multinomial(key, pvals, n, n_max, batch_shape + size)
643
- dtype = dtype or environ.ditype()
644
- return jnp.asarray(r, dtype=dtype)
645
-
646
- def multivariate_normal(
647
- self,
648
- mean,
649
- cov,
650
- size: Optional[Size] = None,
651
- method: str = 'cholesky',
652
- key: Optional[SeedOrKey] = None,
653
- dtype: DTypeLike = None
654
- ):
655
- if method not in {'svd', 'eigh', 'cholesky'}:
656
- raise ValueError("method must be one of {'svd', 'eigh', 'cholesky'}")
657
- dtype = dtype or environ.dftype()
658
- mean = u.math.asarray(_check_py_seq(mean), dtype=dtype)
659
- cov = u.math.asarray(_check_py_seq(cov), dtype=dtype)
660
- if isinstance(mean, u.Quantity):
661
- assert isinstance(cov, u.Quantity)
662
- assert mean.unit ** 2 == cov.unit
663
- mean = mean.mantissa if isinstance(mean, u.Quantity) else mean
664
- cov = cov.mantissa if isinstance(cov, u.Quantity) else cov
665
- unit = mean.unit if isinstance(mean, u.Quantity) else u.Unit()
666
-
667
- key = self.split_key() if key is None else _formalize_key(key)
668
- if not jnp.ndim(mean) >= 1:
669
- raise ValueError(f"multivariate_normal requires mean.ndim >= 1, got mean.ndim == {jnp.ndim(mean)}")
670
- if not jnp.ndim(cov) >= 2:
671
- raise ValueError(f"multivariate_normal requires cov.ndim >= 2, got cov.ndim == {jnp.ndim(cov)}")
672
- n = mean.shape[-1]
673
- if jnp.shape(cov)[-2:] != (n, n):
674
- raise ValueError(f"multivariate_normal requires cov.shape == (..., n, n) for n={n}, "
675
- f"but got cov.shape == {jnp.shape(cov)}.")
676
- if size is None:
677
- size = lax.broadcast_shapes(mean.shape[:-1], cov.shape[:-2])
678
- else:
679
- size = _size2shape(size)
680
- _check_shape("normal", size, mean.shape[:-1], cov.shape[:-2])
681
-
682
- if method == 'svd':
683
- (u, s, _) = jnp.linalg.svd(cov)
684
- factor = u * jnp.sqrt(s[..., None, :])
685
- elif method == 'eigh':
686
- (w, v) = jnp.linalg.eigh(cov)
687
- factor = v * jnp.sqrt(w[..., None, :])
688
- else: # 'cholesky'
689
- factor = jnp.linalg.cholesky(cov)
690
- normal_samples = jr.normal(key, size + mean.shape[-1:], dtype=dtype)
691
- r = mean + jnp.einsum('...ij,...j->...i', factor, normal_samples)
692
- return r if unit.is_unitless else u.Quantity(r, unit=unit)
693
-
694
- def rayleigh(self,
695
- scale=1.0,
696
- size: Optional[Size] = None,
697
- key: Optional[SeedOrKey] = None,
698
- dtype: DTypeLike = None):
699
- scale = _check_py_seq(scale)
700
- if size is None:
701
- size = jnp.shape(scale)
702
- key = self.split_key() if key is None else _formalize_key(key)
703
- dtype = dtype or environ.dftype()
704
- x = jnp.sqrt(-2. * jnp.log(uniform_for_unit(key, shape=_size2shape(size), minval=0, maxval=1, dtype=dtype)))
705
- r = x * scale
706
- return r
707
-
708
- def triangular(self,
709
- size: Optional[Size] = None,
710
- key: Optional[SeedOrKey] = None):
711
- key = self.split_key() if key is None else _formalize_key(key)
712
- bernoulli_samples = jr.bernoulli(key, p=0.5, shape=_size2shape(size))
713
- r = 2 * bernoulli_samples - 1
714
- return r
715
-
716
- def vonmises(self,
717
- mu,
718
- kappa,
719
- size: Optional[Size] = None,
720
- key: Optional[SeedOrKey] = None,
721
- dtype: DTypeLike = None):
722
- key = self.split_key() if key is None else _formalize_key(key)
723
- dtype = dtype or environ.dftype()
724
- mu = jnp.asarray(_check_py_seq(mu), dtype=dtype)
725
- kappa = jnp.asarray(_check_py_seq(kappa), dtype=dtype)
726
- if size is None:
727
- size = lax.broadcast_shapes(jnp.shape(mu), jnp.shape(kappa))
728
- size = _size2shape(size)
729
- samples = _von_mises_centered(key, kappa, size, dtype=dtype)
730
- samples = samples + mu
731
- samples = (samples + jnp.pi) % (2.0 * jnp.pi) - jnp.pi
732
- return samples
733
-
734
- def weibull(self,
735
- a,
736
- size: Optional[Size] = None,
737
- key: Optional[SeedOrKey] = None,
738
- dtype: DTypeLike = None):
739
- key = self.split_key() if key is None else _formalize_key(key)
740
- a = _check_py_seq(a)
741
- if size is None:
742
- size = jnp.shape(a)
743
- else:
744
- if jnp.size(a) > 1:
745
- raise ValueError(f'"a" should be a scalar when "size" is provided. But we got {a}')
746
- size = _size2shape(size)
747
- dtype = dtype or environ.dftype()
748
- random_uniform = uniform_for_unit(key=key, shape=size, minval=0, maxval=1, dtype=dtype)
749
- r = jnp.power(-jnp.log1p(-random_uniform), 1.0 / a)
750
- return r
751
-
752
- def weibull_min(self,
753
- a,
754
- scale=None,
627
+ df = _check_py_seq(df)
628
+ key = self.split_key() if key is None else _formalize_key(key)
629
+ dtype = dtype or environ.dftype()
630
+ if size is None:
631
+ if jnp.ndim(df) == 0:
632
+ dist = jr.normal(key, (df,), dtype=dtype) ** 2
633
+ dist = dist.sum()
634
+ else:
635
+ raise NotImplementedError('Do not support non-scale "df" when "size" is None')
636
+ else:
637
+ dist = jr.normal(key, (df,) + _size2shape(size), dtype=dtype) ** 2
638
+ dist = dist.sum(axis=0)
639
+ return dist
640
+
641
+ def dirichlet(self,
642
+ alpha,
755
643
  size: Optional[Size] = None,
756
644
  key: Optional[SeedOrKey] = None,
757
645
  dtype: DTypeLike = None):
758
- key = self.split_key() if key is None else _formalize_key(key)
759
- a = _check_py_seq(a)
760
- scale = _check_py_seq(scale)
761
- if size is None:
762
- size = jnp.broadcast_shapes(jnp.shape(a), jnp.shape(scale))
763
- else:
764
- if jnp.size(a) > 1:
765
- raise ValueError(f'"a" should be a scalar when "size" is provided. But we got {a}')
766
- size = _size2shape(size)
767
- dtype = dtype or environ.dftype()
768
- random_uniform = uniform_for_unit(key=key, shape=size, minval=0, maxval=1, dtype=dtype)
769
- r = jnp.power(-jnp.log1p(-random_uniform), 1.0 / a)
770
- if scale is not None:
771
- r /= scale
772
- return r
773
-
774
- def maxwell(self,
775
- size: Optional[Size] = None,
776
- key: Optional[SeedOrKey] = None,
777
- dtype: DTypeLike = None):
778
- key = self.split_key() if key is None else _formalize_key(key)
779
- shape = _size2shape(size) + (3,)
780
- dtype = dtype or environ.dftype()
781
- norm_rvs = jr.normal(key=key, shape=shape, dtype=dtype)
782
- r = jnp.linalg.norm(norm_rvs, axis=-1)
783
- return r
784
-
785
- def negative_binomial(self,
786
- n,
787
- p,
788
- size: Optional[Size] = None,
789
- key: Optional[SeedOrKey] = None,
790
- dtype: DTypeLike = None):
791
- n = _check_py_seq(n)
792
- p = _check_py_seq(p)
793
- if size is None:
794
- size = lax.broadcast_shapes(jnp.shape(n), jnp.shape(p))
795
- size = _size2shape(size)
796
- logits = jnp.log(p) - jnp.log1p(-p)
797
- if key is None:
798
- keys = self.split_keys(2)
799
- else:
800
- keys = jr.split(_formalize_key(key), 2)
801
- rate = self.gamma(shape=n, scale=jnp.exp(-logits), size=size, key=keys[0], dtype=environ.dftype())
802
- r = self.poisson(lam=rate, key=keys[1], dtype=dtype or environ.ditype())
803
- return r
804
-
805
- def wald(self,
806
- mean,
807
- scale,
808
- size: Optional[Size] = None,
809
- key: Optional[SeedOrKey] = None,
810
- dtype: DTypeLike = None):
811
- dtype = dtype or environ.dftype()
812
- key = self.split_key() if key is None else _formalize_key(key)
813
- mean = jnp.asarray(_check_py_seq(mean), dtype=dtype)
814
- scale = jnp.asarray(_check_py_seq(scale), dtype=dtype)
815
- if size is None:
816
- size = lax.broadcast_shapes(jnp.shape(mean), jnp.shape(scale))
817
- size = _size2shape(size)
818
- sampled_chi2 = jnp.square(self.randn(*size))
819
- sampled_uniform = self.uniform(size=size, key=key, dtype=dtype)
820
- # Wikipedia defines an intermediate x with the formula
821
- # x = loc + loc ** 2 * y / (2 * conc) - loc / (2 * conc) * sqrt(4 * loc * conc * y + loc ** 2 * y ** 2)
822
- # where y ~ N(0, 1)**2 (sampled_chi2 above) and conc is the concentration.
823
- # Let us write
824
- # w = loc * y / (2 * conc)
825
- # Then we can extract the common factor in the last two terms to obtain
826
- # x = loc + loc * w * (1 - sqrt(2 / w + 1))
827
- # Now we see that the Wikipedia formula suffers from catastrophic
828
- # cancellation for large w (e.g., if conc << loc).
829
- #
830
- # Fortunately, we can fix this by multiplying both sides
831
- # by 1 + sqrt(2 / w + 1). We get
832
- # x * (1 + sqrt(2 / w + 1)) =
833
- # = loc * (1 + sqrt(2 / w + 1)) + loc * w * (1 - (2 / w + 1))
834
- # = loc * (sqrt(2 / w + 1) - 1)
835
- # The term sqrt(2 / w + 1) + 1 no longer presents numerical
836
- # difficulties for large w, and sqrt(2 / w + 1) - 1 is just
837
- # sqrt1pm1(2 / w), which we know how to compute accurately.
838
- # This just leaves the matter of small w, where 2 / w may
839
- # overflow. In the limit as w -> 0, x -> loc, so we just mask
840
- # that case.
841
- sqrt1pm1_arg = 4 * scale / (mean * sampled_chi2) # 2 / w above
842
- safe_sqrt1pm1_arg = jnp.where(sqrt1pm1_arg < np.inf, sqrt1pm1_arg, 1.0)
843
- denominator = 1.0 + jnp.sqrt(safe_sqrt1pm1_arg + 1.0)
844
- ratio = jnp.expm1(0.5 * jnp.log1p(safe_sqrt1pm1_arg)) / denominator
845
- sampled = mean * jnp.where(sqrt1pm1_arg < np.inf, ratio, 1.0) # x above
846
- res = jnp.where(sampled_uniform <= mean / (mean + sampled),
847
- sampled,
848
- jnp.square(mean) / sampled)
849
- return res
850
-
851
- def t(self,
852
- df,
646
+ key = self.split_key() if key is None else _formalize_key(key)
647
+ alpha = _check_py_seq(alpha)
648
+ dtype = dtype or environ.dftype()
649
+ r = jr.dirichlet(key, alpha=alpha, shape=_size2shape(size), dtype=dtype)
650
+ return r
651
+
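`dirichlet` defers directly to `jax.random.dirichlet`; a short sketch of the underlying call (illustrative):

    import jax.numpy as jnp
    import jax.random as jr

    key = jr.PRNGKey(0)
    alpha = jnp.array([1.0, 2.0, 3.0])
    draws = jr.dirichlet(key, alpha, shape=(1000,))
    print(draws.sum(axis=-1)[:3])  # each row sums to 1
    print(draws.mean(axis=0))      # approximately alpha / alpha.sum()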
652
+ def geometric(self,
653
+ p,
654
+ size: Optional[Size] = None,
655
+ key: Optional[SeedOrKey] = None,
656
+ dtype: DTypeLike = None):
657
+ p = _check_py_seq(p)
658
+ if size is None:
659
+ size = jnp.shape(p)
660
+ key = self.split_key() if key is None else _formalize_key(key)
661
+ dtype = dtype or environ.dftype()
662
+ u_ = uniform_for_unit(key, size, dtype=dtype)
663
+ r = jnp.floor(jnp.log1p(-u_) / jnp.log1p(-p))
664
+ return r
665
+
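The geometric sampler above is an inverse-CDF transform: counting failures before the first success gives P(K >= k) = (1 - p)^k, so K = floor(log(1 - U) / log(1 - p)) for U ~ Uniform(0, 1). A standalone check (illustrative):

    import jax.numpy as jnp
    import jax.random as jr

    key = jr.PRNGKey(0)
    p = 0.3
    u = jr.uniform(key, (10_000,))
    k = jnp.floor(jnp.log1p(-u) / jnp.log1p(-p))
    print(k.mean())  # approximately (1 - p) / p, about 2.33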
666
+ def _check_p2(self, p):
667
+ raise ValueError(f'We require `sum(pvals[:-1]) <= 1`. But we got {p}')
668
+
669
+ def multinomial(self,
670
+ n,
671
+ pvals,
672
+ size: Optional[Size] = None,
673
+ key: Optional[SeedOrKey] = None,
674
+ dtype: DTypeLike = None):
675
+ key = self.split_key() if key is None else _formalize_key(key)
676
+ n = _check_py_seq(n)
677
+ pvals = _check_py_seq(pvals)
678
+ jit_error_if(jnp.sum(pvals[:-1]) > 1., self._check_p2, pvals)
679
+ if isinstance(n, jax.core.Tracer):
680
+ raise ValueError("The total count parameter `n` should not be a jax abstract array.")
681
+ size = _size2shape(size)
682
+ n_max = int(np.max(jax.device_get(n)))
683
+ batch_shape = lax.broadcast_shapes(jnp.shape(pvals)[:-1], jnp.shape(n))
684
+ r = _multinomial(key, pvals, n, n_max, batch_shape + size)
685
+ dtype = dtype or environ.ditype()
686
+ return jnp.asarray(r, dtype=dtype)
687
+
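`multinomial` requires a concrete total count `n` (it determines the sample shape) and delegates to the `_multinomial` helper defined near the end of this file, which draws `n_max` categorical indices and histograms them with a scatter-add. The same idea in miniature (illustrative):

    import jax.numpy as jnp
    import jax.random as jr

    key = jr.PRNGKey(0)
    p = jnp.array([0.2, 0.3, 0.5])
    n = 10
    idx = jr.categorical(key, jnp.log(p), shape=(n,))      # n category draws
    counts = jnp.zeros(3, dtype=jnp.int32).at[idx].add(1)  # histogram them
    print(counts, counts.sum())  # the counts sum to n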
688
+ def multivariate_normal(
689
+ self,
690
+ mean,
691
+ cov,
853
692
  size: Optional[Size] = None,
693
+ method: str = 'cholesky',
854
694
  key: Optional[SeedOrKey] = None,
855
- dtype: DTypeLike = None):
856
- dtype = dtype or environ.dftype()
857
- df = jnp.asarray(_check_py_seq(df), dtype=dtype)
858
- if size is None:
859
- size = np.shape(df)
860
- else:
861
- size = _size2shape(size)
862
- _check_shape("t", size, np.shape(df))
863
- if key is None:
864
- keys = self.split_keys(2)
865
- else:
866
- keys = jr.split(_formalize_key(key), 2)
867
- n = jr.normal(keys[0], size, dtype=dtype)
868
- two = _const(n, 2)
869
- half_df = lax.div(df, two)
870
- g = jr.gamma(keys[1], half_df, size, dtype=dtype)
871
- r = n * jnp.sqrt(half_df / g)
872
- return r
873
-
874
- def orthogonal(self,
875
- n: int,
695
+ dtype: DTypeLike = None
696
+ ):
697
+ if method not in {'svd', 'eigh', 'cholesky'}:
698
+ raise ValueError("method must be one of {'svd', 'eigh', 'cholesky'}")
699
+ dtype = dtype or environ.dftype()
700
+ mean = u.math.asarray(_check_py_seq(mean), dtype=dtype)
701
+ cov = u.math.asarray(_check_py_seq(cov), dtype=dtype)
702
+ if isinstance(mean, u.Quantity):
703
+ assert isinstance(cov, u.Quantity)
704
+ assert mean.unit ** 2 == cov.unit
705
+ unit = mean.unit if isinstance(mean, u.Quantity) else u.Unit()
706
+ mean = mean.mantissa if isinstance(mean, u.Quantity) else mean
707
+ cov = cov.mantissa if isinstance(cov, u.Quantity) else cov
708
+
709
+ key = self.split_key() if key is None else _formalize_key(key)
710
+ if not jnp.ndim(mean) >= 1:
711
+ raise ValueError(f"multivariate_normal requires mean.ndim >= 1, got mean.ndim == {jnp.ndim(mean)}")
712
+ if not jnp.ndim(cov) >= 2:
713
+ raise ValueError(f"multivariate_normal requires cov.ndim >= 2, got cov.ndim == {jnp.ndim(cov)}")
714
+ n = mean.shape[-1]
715
+ if jnp.shape(cov)[-2:] != (n, n):
716
+ raise ValueError(f"multivariate_normal requires cov.shape == (..., n, n) for n={n}, "
717
+ f"but got cov.shape == {jnp.shape(cov)}.")
718
+ if size is None:
719
+ size = lax.broadcast_shapes(mean.shape[:-1], cov.shape[:-2])
720
+ else:
721
+ size = _size2shape(size)
722
+ _check_shape("normal", size, mean.shape[:-1], cov.shape[:-2])
723
+
724
+ if method == 'svd':
725
+ (u_, s, _) = jnp.linalg.svd(cov)
726
+ factor = u_ * jnp.sqrt(s[..., None, :])
727
+ elif method == 'eigh':
728
+ (w, v) = jnp.linalg.eigh(cov)
729
+ factor = v * jnp.sqrt(w[..., None, :])
730
+ else: # 'cholesky'
731
+ factor = jnp.linalg.cholesky(cov)
732
+ normal_samples = jr.normal(key, size + mean.shape[-1:], dtype=dtype)
733
+ r = mean + jnp.einsum('...ij,...j->...i', factor, normal_samples)
734
+ return r if unit.is_unitless else u.Quantity(r, unit=unit)
735
+
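The default 'cholesky' branch of `multivariate_normal` computes x = mean + L z with L L^T = cov and z ~ N(0, I). A standalone sketch (illustrative):

    import jax.numpy as jnp
    import jax.random as jr

    key = jr.PRNGKey(0)
    mean = jnp.array([1.0, -1.0])
    cov = jnp.array([[2.0, 0.5],
                     [0.5, 1.0]])
    factor = jnp.linalg.cholesky(cov)  # L such that L @ L.T == cov
    z = jr.normal(key, (10_000, 2))
    x = mean + jnp.einsum('ij,nj->ni', factor, z)
    print(jnp.cov(x.T))  # approximately cov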
736
+ def rayleigh(self,
737
+ scale=1.0,
876
738
  size: Optional[Size] = None,
877
739
  key: Optional[SeedOrKey] = None,
878
740
  dtype: DTypeLike = None):
879
- dtype = dtype or environ.dftype()
880
- key = self.split_key() if key is None else _formalize_key(key)
881
- size = _size2shape(size)
882
- _check_shape("orthogonal", size)
883
- n = core.concrete_or_error(index, n, "The error occurred in jax.random.orthogonal()")
884
- z = jr.normal(key, size + (n, n), dtype=dtype)
885
- q, r = jnp.linalg.qr(z)
886
- d = jnp.diagonal(r, 0, -2, -1)
887
- r = q * jnp.expand_dims(d / abs(d), -2)
888
- return r
889
-
890
- def noncentral_chisquare(self,
891
- df,
892
- nonc,
893
- size: Optional[Size] = None,
894
- key: Optional[SeedOrKey] = None,
895
- dtype: DTypeLike = None):
896
- dtype = dtype or environ.dftype()
897
- df = jnp.asarray(_check_py_seq(df), dtype=dtype)
898
- nonc = jnp.asarray(_check_py_seq(nonc), dtype=dtype)
899
- if size is None:
900
- size = lax.broadcast_shapes(jnp.shape(df), jnp.shape(nonc))
901
- size = _size2shape(size)
902
- if key is None:
903
- keys = self.split_keys(3)
904
- else:
905
- keys = jr.split(_formalize_key(key), 3)
906
- i = jr.poisson(keys[0], 0.5 * nonc, shape=size, dtype=environ.ditype())
907
- n = jr.normal(keys[1], shape=size, dtype=dtype) + jnp.sqrt(nonc)
908
- cond = jnp.greater(df, 1.0)
909
- df2 = jnp.where(cond, df - 1.0, df + 2.0 * i)
910
- chi2 = 2.0 * jr.gamma(keys[2], 0.5 * df2, shape=size, dtype=dtype)
911
- r = jnp.where(cond, chi2 + n * n, chi2)
912
- return r
913
-
914
- def loggamma(self,
915
- a,
916
- size: Optional[Size] = None,
917
- key: Optional[SeedOrKey] = None,
918
- dtype: DTypeLike = None):
919
- dtype = dtype or environ.dftype()
920
- key = self.split_key() if key is None else _formalize_key(key)
921
- a = _check_py_seq(a)
922
- if size is None:
923
- size = jnp.shape(a)
924
- r = jr.loggamma(key, a, shape=_size2shape(size), dtype=dtype)
925
- return r
926
-
927
- def categorical(self,
928
- logits,
929
- axis: int = -1,
930
- size: Optional[Size] = None,
931
- key: Optional[SeedOrKey] = None):
932
- key = self.split_key() if key is None else _formalize_key(key)
933
- logits = _check_py_seq(logits)
934
- if size is None:
935
- size = list(jnp.shape(logits))
936
- size.pop(axis)
937
- r = jr.categorical(key, logits, axis=axis, shape=_size2shape(size))
938
- return r
939
-
940
- def zipf(self,
941
- a,
942
- size: Optional[Size] = None,
943
- key: Optional[SeedOrKey] = None,
944
- dtype: DTypeLike = None):
945
- a = _check_py_seq(a)
946
- if size is None:
947
- size = jnp.shape(a)
948
- dtype = dtype or environ.ditype()
949
- r = jax.pure_callback(lambda x: np.random.zipf(x, size).astype(dtype),
950
- jax.ShapeDtypeStruct(size, dtype),
951
- a)
952
- return r
953
-
954
- def power(self,
955
- a,
956
- size: Optional[Size] = None,
957
- key: Optional[SeedOrKey] = None,
958
- dtype: DTypeLike = None):
959
- a = _check_py_seq(a)
960
- if size is None:
961
- size = jnp.shape(a)
962
- size = _size2shape(size)
963
- dtype = dtype or environ.dftype()
964
- r = jax.pure_callback(lambda a: np.random.power(a=a, size=size).astype(dtype),
965
- jax.ShapeDtypeStruct(size, dtype),
966
- a)
967
- return r
968
-
969
- def f(self,
970
- dfnum,
971
- dfden,
972
- size: Optional[Size] = None,
973
- key: Optional[SeedOrKey] = None,
974
- dtype: DTypeLike = None):
975
- dfnum = _check_py_seq(dfnum)
976
- dfden = _check_py_seq(dfden)
977
- if size is None:
978
- size = jnp.broadcast_shapes(jnp.shape(dfnum), jnp.shape(dfden))
979
- size = _size2shape(size)
980
- d = {'dfnum': dfnum, 'dfden': dfden}
981
- dtype = dtype or environ.dftype()
982
- r = jax.pure_callback(lambda dfnum_, dfden_: np.random.f(dfnum=dfnum_,
983
- dfden=dfden_,
984
- size=size).astype(dtype),
985
- jax.ShapeDtypeStruct(size, dtype),
986
- dfnum, dfden)
987
- return r
988
-
989
- def hypergeometric(
990
- self,
991
- ngood,
992
- nbad,
993
- nsample,
994
- size: Optional[Size] = None,
995
- key: Optional[SeedOrKey] = None,
996
- dtype: DTypeLike = None
997
- ):
998
- ngood = _check_py_seq(ngood)
999
- nbad = _check_py_seq(nbad)
1000
- nsample = _check_py_seq(nsample)
1001
-
1002
- if size is None:
1003
- size = lax.broadcast_shapes(jnp.shape(ngood),
1004
- jnp.shape(nbad),
1005
- jnp.shape(nsample))
1006
- size = _size2shape(size)
1007
- dtype = dtype or environ.ditype()
1008
- d = {'ngood': ngood, 'nbad': nbad, 'nsample': nsample}
1009
- r = jax.pure_callback(lambda d: np.random.hypergeometric(ngood=d['ngood'],
1010
- nbad=d['nbad'],
1011
- nsample=d['nsample'],
1012
- size=size).astype(dtype),
1013
- jax.ShapeDtypeStruct(size, dtype),
1014
- d)
1015
- return r
1016
-
1017
- def logseries(self,
1018
- p,
741
+ scale = _check_py_seq(scale)
742
+ if size is None:
743
+ size = jnp.shape(scale)
744
+ key = self.split_key() if key is None else _formalize_key(key)
745
+ dtype = dtype or environ.dftype()
746
+ x = jnp.sqrt(-2. * jnp.log(uniform_for_unit(key, shape=_size2shape(size), minval=0, maxval=1, dtype=dtype)))
747
+ r = x * scale
748
+ return r
749
+
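`rayleigh` is another inverse-CDF sampler: F(x) = 1 - exp(-x^2 / (2 scale^2)) inverts to x = scale * sqrt(-2 log(1 - u)), and U and 1 - U are equidistributed. A standalone check (illustrative):

    import jax.numpy as jnp
    import jax.random as jr

    key = jr.PRNGKey(0)
    scale = 2.0
    u = jr.uniform(key, (10_000,))
    x = scale * jnp.sqrt(-2.0 * jnp.log(u))
    print(x.mean())  # approximately scale * sqrt(pi / 2), about 2.51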
750
+ def triangular(self,
751
+ size: Optional[Size] = None,
752
+ key: Optional[SeedOrKey] = None):
753
+ key = self.split_key() if key is None else _formalize_key(key)
754
+ bernoulli_samples = jr.bernoulli(key, p=0.5, shape=_size2shape(size))
755
+ r = 2 * bernoulli_samples - 1
756
+ return r
757
+
758
+ def vonmises(self,
759
+ mu,
760
+ kappa,
761
+ size: Optional[Size] = None,
762
+ key: Optional[SeedOrKey] = None,
763
+ dtype: DTypeLike = None):
764
+ key = self.split_key() if key is None else _formalize_key(key)
765
+ dtype = dtype or environ.dftype()
766
+ mu = jnp.asarray(_check_py_seq(mu), dtype=dtype)
767
+ kappa = jnp.asarray(_check_py_seq(kappa), dtype=dtype)
768
+ if size is None:
769
+ size = lax.broadcast_shapes(jnp.shape(mu), jnp.shape(kappa))
770
+ size = _size2shape(size)
771
+ samples = _von_mises_centered(key, kappa, size, dtype=dtype)
772
+ samples = samples + mu
773
+ samples = (samples + jnp.pi) % (2.0 * jnp.pi) - jnp.pi
774
+ return samples
775
+
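After shifting by the location `mu`, `vonmises` wraps samples back onto [-pi, pi); the modular wrap used above behaves like this (illustrative):

    import jax.numpy as jnp

    theta = jnp.array([3.5, -4.0, 0.1])
    wrapped = (theta + jnp.pi) % (2.0 * jnp.pi) - jnp.pi
    print(wrapped)  # every entry now lies in [-pi, pi)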
776
+ def weibull(self,
777
+ a,
1019
778
  size: Optional[Size] = None,
1020
779
  key: Optional[SeedOrKey] = None,
1021
780
  dtype: DTypeLike = None):
1022
- p = _check_py_seq(p)
1023
- if size is None:
1024
- size = jnp.shape(p)
1025
- size = _size2shape(size)
1026
- dtype = dtype or environ.ditype()
1027
- r = jax.pure_callback(lambda p: np.random.logseries(p=p, size=size).astype(dtype),
1028
- jax.ShapeDtypeStruct(size, dtype),
1029
- p)
1030
- return r
1031
-
1032
- def noncentral_f(self,
1033
- dfnum,
1034
- dfden,
1035
- nonc,
781
+ key = self.split_key() if key is None else _formalize_key(key)
782
+ a = _check_py_seq(a)
783
+ if size is None:
784
+ size = jnp.shape(a)
785
+ else:
786
+ if jnp.size(a) > 1:
787
+ raise ValueError(f'"a" should be a scalar when "size" is provided. But we got {a}')
788
+ size = _size2shape(size)
789
+ dtype = dtype or environ.dftype()
790
+ random_uniform = uniform_for_unit(key=key, shape=size, minval=0, maxval=1, dtype=dtype)
791
+ r = jnp.power(-jnp.log1p(-random_uniform), 1.0 / a)
792
+ return r
793
+
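`weibull` likewise inverts the CDF: F(x) = 1 - exp(-x^a) gives x = (-log(1 - u))^(1/a). A standalone check (illustrative):

    import jax.numpy as jnp
    import jax.random as jr

    key = jr.PRNGKey(0)
    a = 1.5
    u = jr.uniform(key, (10_000,))
    x = jnp.power(-jnp.log1p(-u), 1.0 / a)
    print(x.mean())  # approximately Gamma(1 + 1/a), about 0.90 for a = 1.5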
794
+ def weibull_min(self,
795
+ a,
796
+ scale=None,
797
+ size: Optional[Size] = None,
798
+ key: Optional[SeedOrKey] = None,
799
+ dtype: DTypeLike = None):
800
+ key = self.split_key() if key is None else _formalize_key(key)
801
+ a = _check_py_seq(a)
802
+ scale = _check_py_seq(scale)
803
+ if size is None:
804
+ size = jnp.broadcast_shapes(jnp.shape(a), jnp.shape(scale))
805
+ else:
806
+ if jnp.size(a) > 1:
807
+ raise ValueError(f'"a" should be a scalar when "size" is provided. But we got {a}')
808
+ size = _size2shape(size)
809
+ dtype = dtype or environ.dftype()
810
+ random_uniform = uniform_for_unit(key=key, shape=size, minval=0, maxval=1, dtype=dtype)
811
+ r = jnp.power(-jnp.log1p(-random_uniform), 1.0 / a)
812
+ if scale is not None:
813
+ r /= scale
814
+ return r
815
+
816
+ def maxwell(self,
817
+ size: Optional[Size] = None,
818
+ key: Optional[SeedOrKey] = None,
819
+ dtype: DTypeLike = None):
820
+ key = self.split_key() if key is None else _formalize_key(key)
821
+ shape = _size2shape(size) + (3,)
822
+ dtype = dtype or environ.dftype()
823
+ norm_rvs = jr.normal(key=key, shape=shape, dtype=dtype)
824
+ r = jnp.linalg.norm(norm_rvs, axis=-1)
825
+ return r
826
+
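`maxwell` relies on the Maxwell-Boltzmann distribution being the norm of a 3-vector of independent standard normals (illustrative sketch):

    import jax.numpy as jnp
    import jax.random as jr

    key = jr.PRNGKey(0)
    v = jr.normal(key, (10_000, 3))
    speed = jnp.linalg.norm(v, axis=-1)
    print(speed.mean())  # approximately 2 * sqrt(2 / pi), about 1.60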
827
+ def negative_binomial(self,
828
+ n,
829
+ p,
830
+ size: Optional[Size] = None,
831
+ key: Optional[SeedOrKey] = None,
832
+ dtype: DTypeLike = None):
833
+ n = _check_py_seq(n)
834
+ p = _check_py_seq(p)
835
+ if size is None:
836
+ size = lax.broadcast_shapes(jnp.shape(n), jnp.shape(p))
837
+ size = _size2shape(size)
838
+ logits = jnp.log(p) - jnp.log1p(-p)
839
+ if key is None:
840
+ keys = self.split_key(2)
841
+ else:
842
+ keys = jr.split(_formalize_key(key), 2)
843
+ rate = self.gamma(shape=n, scale=jnp.exp(-logits), size=size, key=keys[0], dtype=environ.dftype())
844
+ r = self.poisson(lam=rate, key=keys[1], dtype=dtype or environ.ditype())
845
+ return r
846
+
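`negative_binomial` samples through the Gamma-Poisson mixture: with logits = log(p / (1 - p)), exp(-logits) = (1 - p) / p, so rate ~ Gamma(n, scale=(1 - p) / p) followed by k ~ Poisson(rate). A standalone check (illustrative):

    import jax.numpy as jnp
    import jax.random as jr

    key1, key2 = jr.split(jr.PRNGKey(0))
    n, p = 5.0, 0.4
    rate = jr.gamma(key1, n, (10_000,)) * (1 - p) / p
    k = jr.poisson(key2, rate)
    print(k.mean())  # approximately n * (1 - p) / p = 7.5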
847
+ def wald(self,
848
+ mean,
849
+ scale,
850
+ size: Optional[Size] = None,
851
+ key: Optional[SeedOrKey] = None,
852
+ dtype: DTypeLike = None):
853
+ dtype = dtype or environ.dftype()
854
+ key = self.split_key() if key is None else _formalize_key(key)
855
+ mean = jnp.asarray(_check_py_seq(mean), dtype=dtype)
856
+ scale = jnp.asarray(_check_py_seq(scale), dtype=dtype)
857
+ if size is None:
858
+ size = lax.broadcast_shapes(jnp.shape(mean), jnp.shape(scale))
859
+ size = _size2shape(size)
860
+ sampled_chi2 = jnp.square(self.randn(*size))
861
+ sampled_uniform = self.uniform(size=size, key=key, dtype=dtype)
862
+ # Wikipedia defines an intermediate x with the formula
863
+ # x = loc + loc ** 2 * y / (2 * conc) - loc / (2 * conc) * sqrt(4 * loc * conc * y + loc ** 2 * y ** 2)
864
+ # where y ~ N(0, 1)**2 (sampled_chi2 above) and conc is the concentration.
865
+ # Let us write
866
+ # w = loc * y / (2 * conc)
867
+ # Then we can extract the common factor in the last two terms to obtain
868
+ # x = loc + loc * w * (1 - sqrt(2 / w + 1))
869
+ # Now we see that the Wikipedia formula suffers from catastrophic
870
+ # cancellation for large w (e.g., if conc << loc).
871
+ #
872
+ # Fortunately, we can fix this by multiplying both sides
873
+ # by 1 + sqrt(2 / w + 1). We get
874
+ # x * (1 + sqrt(2 / w + 1)) =
875
+ # = loc * (1 + sqrt(2 / w + 1)) + loc * w * (1 - (2 / w + 1))
876
+ # = loc * (sqrt(2 / w + 1) - 1)
877
+ # The term sqrt(2 / w + 1) + 1 no longer presents numerical
878
+ # difficulties for large w, and sqrt(2 / w + 1) - 1 is just
879
+ # sqrt1pm1(2 / w), which we know how to compute accurately.
880
+ # This just leaves the matter of small w, where 2 / w may
881
+ # overflow. In the limit as w -> 0, x -> loc, so we just mask
882
+ # that case.
883
+ sqrt1pm1_arg = 4 * scale / (mean * sampled_chi2) # 2 / w above
884
+ safe_sqrt1pm1_arg = jnp.where(sqrt1pm1_arg < np.inf, sqrt1pm1_arg, 1.0)
885
+ denominator = 1.0 + jnp.sqrt(safe_sqrt1pm1_arg + 1.0)
886
+ ratio = jnp.expm1(0.5 * jnp.log1p(safe_sqrt1pm1_arg)) / denominator
887
+ sampled = mean * jnp.where(sqrt1pm1_arg < np.inf, ratio, 1.0) # x above
888
+ res = jnp.where(sampled_uniform <= mean / (mean + sampled),
889
+ sampled,
890
+ jnp.square(mean) / sampled)
891
+ return res
892
+
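For reference, the transformation the comments above derive, written out (with mu = `mean`, lambda = `scale`, and y ~ chi-square with one degree of freedom; this matches the standard Michael-Schucany-Haas construction for the inverse Gaussian):

    w = \frac{\mu y}{2\lambda}, \qquad
    x = \mu + \mu w \left(1 - \sqrt{2/w + 1}\right)
      = \mu \cdot \frac{\sqrt{2/w + 1} - 1}{\sqrt{2/w + 1} + 1}

The final `jnp.where` returns x with probability mu / (mu + x) and mu^2 / x otherwise.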
893
+ def t(self,
894
+ df,
895
+ size: Optional[Size] = None,
896
+ key: Optional[SeedOrKey] = None,
897
+ dtype: DTypeLike = None):
898
+ dtype = dtype or environ.dftype()
899
+ df = jnp.asarray(_check_py_seq(df), dtype=dtype)
900
+ if size is None:
901
+ size = np.shape(df)
902
+ else:
903
+ size = _size2shape(size)
904
+ _check_shape("t", size, np.shape(df))
905
+ if key is None:
906
+ keys = self.split_key(2)
907
+ else:
908
+ keys = jr.split(_formalize_key(key), 2)
909
+ n = jr.normal(keys[0], size, dtype=dtype)
910
+ two = _const(n, 2)
911
+ half_df = lax.div(df, two)
912
+ g = jr.gamma(keys[1], half_df, size, dtype=dtype)
913
+ r = n * jnp.sqrt(half_df / g)
914
+ return r
915
+
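`t` composes a normal with a gamma: for z ~ N(0, 1) and g ~ Gamma(df/2), 2g is chi-square with df degrees of freedom, so z * sqrt((df/2) / g) is Student's t. A standalone check (illustrative):

    import jax.numpy as jnp
    import jax.random as jr

    k1, k2 = jr.split(jr.PRNGKey(0))
    df = 5.0
    z = jr.normal(k1, (10_000,))
    g = jr.gamma(k2, df / 2.0, (10_000,))
    t = z * jnp.sqrt((df / 2.0) / g)
    print(t.var())  # approximately df / (df - 2), about 1.67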
916
+ def orthogonal(self,
917
+ n: int,
1036
918
  size: Optional[Size] = None,
1037
919
  key: Optional[SeedOrKey] = None,
1038
920
  dtype: DTypeLike = None):
1039
- dfnum = _check_py_seq(dfnum)
1040
- dfden = _check_py_seq(dfden)
1041
- nonc = _check_py_seq(nonc)
1042
- if size is None:
1043
- size = lax.broadcast_shapes(jnp.shape(dfnum),
1044
- jnp.shape(dfden),
1045
- jnp.shape(nonc))
1046
- size = _size2shape(size)
1047
- d = {'dfnum': dfnum, 'dfden': dfden, 'nonc': nonc}
1048
- dtype = dtype or environ.dftype()
1049
- r = jax.pure_callback(lambda x: np.random.noncentral_f(dfnum=x['dfnum'],
1050
- dfden=x['dfden'],
1051
- nonc=x['nonc'],
1052
- size=size).astype(dtype),
1053
- jax.ShapeDtypeStruct(size, dtype),
1054
- d)
1055
- return r
1056
-
1057
- # PyTorch compatibility #
1058
- # --------------------- #
1059
-
1060
- def rand_like(self, input, *, dtype=None, key: Optional[SeedOrKey] = None):
1061
- """Returns a tensor with the same size as input that is filled with random
1062
- numbers from a uniform distribution on the interval ``[0, 1)``.
1063
-
1064
- Args:
1065
- input: the ``size`` of input will determine size of the output tensor.
1066
- dtype: the desired data type of returned Tensor. Default: if ``None``, defaults to the dtype of input.
1067
- key: the seed or key for the random.
1068
-
1069
- Returns:
1070
- The random data.
1071
- """
1072
- return self.random(jnp.shape(input), key=key).astype(dtype)
1073
-
1074
- def randn_like(self, input, *, dtype=None, key: Optional[SeedOrKey] = None):
1075
- """Returns a tensor with the same size as ``input`` that is filled with
1076
- random numbers from a normal distribution with mean 0 and variance 1.
1077
-
1078
- Args:
1079
- input: the ``size`` of input will determine size of the output tensor.
1080
- dtype: the desired data type of returned Tensor. Default: if ``None``, defaults to the dtype of input.
1081
- key: the seed or key for the random.
1082
-
1083
- Returns:
1084
- The random data.
1085
- """
1086
- return self.randn(*jnp.shape(input), key=key).astype(dtype)
1087
-
1088
- def randint_like(self, input, low=0, high=None, *, dtype=None, key: Optional[SeedOrKey] = None):
1089
- if high is None:
1090
- high = max(input)
1091
- return self.randint(low, high=high, size=jnp.shape(input), dtype=dtype, key=key)
921
+ dtype = dtype or environ.dftype()
922
+ key = self.split_key() if key is None else _formalize_key(key)
923
+ size = _size2shape(size)
924
+ _check_shape("orthogonal", size)
925
+ n = core.concrete_or_error(index, n, "The error occurred in jax.random.orthogonal()")
926
+ z = jr.normal(key, size + (n, n), dtype=dtype)
927
+ q, r = jnp.linalg.qr(z)
928
+ d = jnp.diagonal(r, 0, -2, -1)
929
+ r = q * jnp.expand_dims(d / abs(d), -2)
930
+ return r
931
+
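`orthogonal` draws a Haar-distributed orthogonal matrix: QR-factorize a Gaussian matrix, then fold the signs of R's diagonal into Q to make the factorization unique. A standalone sketch (illustrative):

    import jax.numpy as jnp
    import jax.random as jr

    key = jr.PRNGKey(0)
    n = 4
    z = jr.normal(key, (n, n))
    q, r = jnp.linalg.qr(z)
    d = jnp.diagonal(r)
    q = q * (d / jnp.abs(d))  # column-wise sign correction
    print(jnp.allclose(q @ q.T, jnp.eye(n), atol=1e-5))  # True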
932
+ def noncentral_chisquare(self,
933
+ df,
934
+ nonc,
935
+ size: Optional[Size] = None,
936
+ key: Optional[SeedOrKey] = None,
937
+ dtype: DTypeLike = None):
938
+ dtype = dtype or environ.dftype()
939
+ df = jnp.asarray(_check_py_seq(df), dtype=dtype)
940
+ nonc = jnp.asarray(_check_py_seq(nonc), dtype=dtype)
941
+ if size is None:
942
+ size = lax.broadcast_shapes(jnp.shape(df), jnp.shape(nonc))
943
+ size = _size2shape(size)
944
+ if key is None:
945
+ keys = self.split_key(3)
946
+ else:
947
+ keys = jr.split(_formalize_key(key), 3)
948
+ i = jr.poisson(keys[0], 0.5 * nonc, shape=size, dtype=environ.ditype())
949
+ n = jr.normal(keys[1], shape=size, dtype=dtype) + jnp.sqrt(nonc)
950
+ cond = jnp.greater(df, 1.0)
951
+ df2 = jnp.where(cond, df - 1.0, df + 2.0 * i)
952
+ chi2 = 2.0 * jr.gamma(keys[2], 0.5 * df2, shape=size, dtype=dtype)
953
+ r = jnp.where(cond, chi2 + n * n, chi2)
954
+ return r
955
+
956
+ def loggamma(self,
957
+ a,
958
+ size: Optional[Size] = None,
959
+ key: Optional[SeedOrKey] = None,
960
+ dtype: DTypeLike = None):
961
+ dtype = dtype or environ.dftype()
962
+ key = self.split_key() if key is None else _formalize_key(key)
963
+ a = _check_py_seq(a)
964
+ if size is None:
965
+ size = jnp.shape(a)
966
+ r = jr.loggamma(key, a, shape=_size2shape(size), dtype=dtype)
967
+ return r
968
+
969
+ def categorical(self,
970
+ logits,
971
+ axis: int = -1,
972
+ size: Optional[Size] = None,
973
+ key: Optional[SeedOrKey] = None):
974
+ key = self.split_key() if key is None else _formalize_key(key)
975
+ logits = _check_py_seq(logits)
976
+ if size is None:
977
+ size = list(jnp.shape(logits))
978
+ size.pop(axis)
979
+ r = jr.categorical(key, logits, axis=axis, shape=_size2shape(size))
980
+ return r
981
+
982
+ def zipf(self,
983
+ a,
984
+ size: Optional[Size] = None,
985
+ key: Optional[SeedOrKey] = None,
986
+ dtype: DTypeLike = None):
987
+ a = _check_py_seq(a)
988
+ if size is None:
989
+ size = jnp.shape(a)
990
+ dtype = dtype or environ.ditype()
991
+ r = jax.pure_callback(lambda x: np.random.zipf(x, size).astype(dtype),
992
+ jax.ShapeDtypeStruct(size, dtype),
993
+ a)
994
+ return r
995
+
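`zipf` and the NumPy-backed samplers below it use `jax.pure_callback` to run the draw on the host, declaring the result's shape and dtype up front since JAX has no native implementation. The pattern in isolation (hypothetical helper name; illustrative):

    import jax
    import jax.numpy as jnp
    import numpy as np

    def np_zipf(a, size=(1000,), dtype=np.int32):
        # host callback into NumPy; the result must match the declared struct
        return jax.pure_callback(
            lambda x: np.random.zipf(x, size).astype(dtype),
            jax.ShapeDtypeStruct(size, dtype),
            a,
        )

    print(np_zipf(jnp.float32(2.0)).min())  # Zipf support starts at 1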
996
+ def power(self,
997
+ a,
998
+ size: Optional[Size] = None,
999
+ key: Optional[SeedOrKey] = None,
1000
+ dtype: DTypeLike = None):
1001
+ a = _check_py_seq(a)
1002
+ if size is None:
1003
+ size = jnp.shape(a)
1004
+ size = _size2shape(size)
1005
+ dtype = dtype or environ.dftype()
1006
+ r = jax.pure_callback(lambda a: np.random.power(a=a, size=size).astype(dtype),
1007
+ jax.ShapeDtypeStruct(size, dtype),
1008
+ a)
1009
+ return r
1010
+
1011
+ def f(self,
1012
+ dfnum,
1013
+ dfden,
1014
+ size: Optional[Size] = None,
1015
+ key: Optional[SeedOrKey] = None,
1016
+ dtype: DTypeLike = None):
1017
+ dfnum = _check_py_seq(dfnum)
1018
+ dfden = _check_py_seq(dfden)
1019
+ if size is None:
1020
+ size = jnp.broadcast_shapes(jnp.shape(dfnum), jnp.shape(dfden))
1021
+ size = _size2shape(size)
1022
+ d = {'dfnum': dfnum, 'dfden': dfden}
1023
+ dtype = dtype or environ.dftype()
1024
+ r = jax.pure_callback(lambda dfnum_, dfden_: np.random.f(dfnum=dfnum_,
1025
+ dfden=dfden_,
1026
+ size=size).astype(dtype),
1027
+ jax.ShapeDtypeStruct(size, dtype),
1028
+ dfnum, dfden)
1029
+ return r
1030
+
1031
+ def hypergeometric(
1032
+ self,
1033
+ ngood,
1034
+ nbad,
1035
+ nsample,
1036
+ size: Optional[Size] = None,
1037
+ key: Optional[SeedOrKey] = None,
1038
+ dtype: DTypeLike = None
1039
+ ):
1040
+ ngood = _check_py_seq(ngood)
1041
+ nbad = _check_py_seq(nbad)
1042
+ nsample = _check_py_seq(nsample)
1043
+
1044
+ if size is None:
1045
+ size = lax.broadcast_shapes(jnp.shape(ngood),
1046
+ jnp.shape(nbad),
1047
+ jnp.shape(nsample))
1048
+ size = _size2shape(size)
1049
+ dtype = dtype or environ.ditype()
1050
+ d = {'ngood': ngood, 'nbad': nbad, 'nsample': nsample}
1051
+ r = jax.pure_callback(lambda d: np.random.hypergeometric(ngood=d['ngood'],
1052
+ nbad=d['nbad'],
1053
+ nsample=d['nsample'],
1054
+ size=size).astype(dtype),
1055
+ jax.ShapeDtypeStruct(size, dtype),
1056
+ d)
1057
+ return r
1058
+
1059
+ def logseries(self,
1060
+ p,
1061
+ size: Optional[Size] = None,
1062
+ key: Optional[SeedOrKey] = None,
1063
+ dtype: DTypeLike = None):
1064
+ p = _check_py_seq(p)
1065
+ if size is None:
1066
+ size = jnp.shape(p)
1067
+ size = _size2shape(size)
1068
+ dtype = dtype or environ.ditype()
1069
+ r = jax.pure_callback(lambda p: np.random.logseries(p=p, size=size).astype(dtype),
1070
+ jax.ShapeDtypeStruct(size, dtype),
1071
+ p)
1072
+ return r
1073
+
1074
+ def noncentral_f(self,
1075
+ dfnum,
1076
+ dfden,
1077
+ nonc,
1078
+ size: Optional[Size] = None,
1079
+ key: Optional[SeedOrKey] = None,
1080
+ dtype: DTypeLike = None):
1081
+ dfnum = _check_py_seq(dfnum)
1082
+ dfden = _check_py_seq(dfden)
1083
+ nonc = _check_py_seq(nonc)
1084
+ if size is None:
1085
+ size = lax.broadcast_shapes(jnp.shape(dfnum),
1086
+ jnp.shape(dfden),
1087
+ jnp.shape(nonc))
1088
+ size = _size2shape(size)
1089
+ d = {'dfnum': dfnum, 'dfden': dfden, 'nonc': nonc}
1090
+ dtype = dtype or environ.dftype()
1091
+ r = jax.pure_callback(lambda x: np.random.noncentral_f(dfnum=x['dfnum'],
1092
+ dfden=x['dfden'],
1093
+ nonc=x['nonc'],
1094
+ size=size).astype(dtype),
1095
+ jax.ShapeDtypeStruct(size, dtype),
1096
+ d)
1097
+ return r
1098
+
1099
+ # PyTorch compatibility #
1100
+ # --------------------- #
1101
+
1102
+ def rand_like(self, input, *, dtype=None, key: Optional[SeedOrKey] = None):
1103
+ """Returns a tensor with the same size as input that is filled with random
1104
+ numbers from a uniform distribution on the interval ``[0, 1)``.
1105
+
1106
+ Args:
1107
+ input: the ``size`` of input will determine size of the output tensor.
1108
+ dtype: the desired data type of returned Tensor. Default: if ``None``, defaults to the dtype of input.
1109
+ key: the seed or key for the random.
1110
+
1111
+ Returns:
1112
+ The random data.
1113
+ """
1114
+ return self.random(jnp.shape(input), key=key).astype(dtype)
1115
+
1116
+ def randn_like(self, input, *, dtype=None, key: Optional[SeedOrKey] = None):
1117
+ """Returns a tensor with the same size as ``input`` that is filled with
1118
+ random numbers from a normal distribution with mean 0 and variance 1.
1119
+
1120
+ Args:
1121
+ input: the ``size`` of input will determine size of the output tensor.
1122
+ dtype: the desired data type of returned Tensor. Default: if ``None``, defaults to the dtype of input.
1123
+ key: the seed or key for the random.
1124
+
1125
+ Returns:
1126
+ The random data.
1127
+ """
1128
+ return self.randn(*jnp.shape(input), key=key).astype(dtype)
1129
+
1130
+ def randint_like(self, input, low=0, high=None, *, dtype=None, key: Optional[SeedOrKey] = None):
1131
+ if high is None:
1132
+ high = max(input)
1133
+ return self.randint(low, high=high, size=jnp.shape(input), dtype=dtype, key=key)
1092
1134
 
1093
1135
 
1094
1136
  # default random generator
@@ -1099,393 +1141,393 @@ DEFAULT = RandomState(np.random.randint(0, 10000, size=2, dtype=np.uint32))
1099
1141
 
1100
1142
 
1101
1143
  def _formalize_key(key):
1102
- if isinstance(key, int):
1103
- return jr.PRNGKey(key)
1104
- elif isinstance(key, (jax.Array, np.ndarray)):
1105
- if key.dtype != jnp.uint32:
1106
- raise TypeError('key must be an int or an array with two uint32.')
1107
- if key.size != 2:
1108
- raise TypeError('key must be an int or an array with two uint32.')
1109
- return jnp.asarray(key, dtype=jnp.uint32)
1110
- else:
1111
- raise TypeError('key must be an int or an array with two uint32.')
1144
+ if isinstance(key, int):
1145
+ return jr.PRNGKey(key)
1146
+ elif isinstance(key, (jax.Array, np.ndarray)):
1147
+ if key.dtype != jnp.uint32:
1148
+ raise TypeError('key must be an int or an array with two uint32.')
1149
+ if key.size != 2:
1150
+ raise TypeError('key must be an int or an array with two uint32.')
1151
+ return jnp.asarray(key, dtype=jnp.uint32)
1152
+ else:
1153
+ raise TypeError('key must be an int or an array with two uint32.')
1112
1154
 
1113
1155
 
1114
1156
  def _size2shape(size):
1115
- if size is None:
1116
- return ()
1117
- elif isinstance(size, (tuple, list)):
1118
- return tuple(size)
1119
- else:
1120
- return (size,)
1157
+ if size is None:
1158
+ return ()
1159
+ elif isinstance(size, (tuple, list)):
1160
+ return tuple(size)
1161
+ else:
1162
+ return (size,)
1121
1163
 
1122
1164
 
1123
1165
  def _check_shape(name, shape, *param_shapes):
1124
- if param_shapes:
1125
- shape_ = lax.broadcast_shapes(shape, *param_shapes)
1126
- if shape != shape_:
1127
- msg = ("{} parameter shapes must be broadcast-compatible with shape "
1128
- "argument, and the result of broadcasting the shapes must equal "
1129
- "the shape argument, but got result {} for shape argument {}.")
1130
- raise ValueError(msg.format(name, shape_, shape))
1166
+ if param_shapes:
1167
+ shape_ = lax.broadcast_shapes(shape, *param_shapes)
1168
+ if shape != shape_:
1169
+ msg = ("{} parameter shapes must be broadcast-compatible with shape "
1170
+ "argument, and the result of broadcasting the shapes must equal "
1171
+ "the shape argument, but got result {} for shape argument {}.")
1172
+ raise ValueError(msg.format(name, shape_, shape))
1131
1173
 
1132
1174
 
1133
1175
  def _is_python_scalar(x):
1134
- if hasattr(x, 'aval'):
1135
- return x.aval.weak_type
1136
- elif np.ndim(x) == 0:
1137
- return True
1138
- elif isinstance(x, (bool, int, float, complex)):
1139
- return True
1140
- else:
1141
- return False
1176
+ if hasattr(x, 'aval'):
1177
+ return x.aval.weak_type
1178
+ elif np.ndim(x) == 0:
1179
+ return True
1180
+ elif isinstance(x, (bool, int, float, complex)):
1181
+ return True
1182
+ else:
1183
+ return False
1142
1184
 
1143
1185
 
1144
1186
  python_scalar_dtypes = {
1145
- bool: np.dtype('bool'),
1146
- int: np.dtype('int64'),
1147
- float: np.dtype('float64'),
1148
- complex: np.dtype('complex128'),
1187
+ bool: np.dtype('bool'),
1188
+ int: np.dtype('int64'),
1189
+ float: np.dtype('float64'),
1190
+ complex: np.dtype('complex128'),
1149
1191
  }
1150
1192
 
1151
1193
 
1152
1194
  def _dtype(x, *, canonicalize: bool = False):
1153
- """Return the dtype object for a value or type, optionally canonicalized based on X64 mode."""
1154
- if x is None:
1155
- raise ValueError(f"Invalid argument to dtype: {x}.")
1156
- elif isinstance(x, type) and x in python_scalar_dtypes:
1157
- dt = python_scalar_dtypes[x]
1158
- elif type(x) in python_scalar_dtypes:
1159
- dt = python_scalar_dtypes[type(x)]
1160
- elif hasattr(x, 'dtype'):
1161
- dt = x.dtype
1162
- else:
1163
- dt = np.result_type(x)
1164
- return dtypes.canonicalize_dtype(dt) if canonicalize else dt
1195
+ """Return the dtype object for a value or type, optionally canonicalized based on X64 mode."""
1196
+ if x is None:
1197
+ raise ValueError(f"Invalid argument to dtype: {x}.")
1198
+ elif isinstance(x, type) and x in python_scalar_dtypes:
1199
+ dt = python_scalar_dtypes[x]
1200
+ elif type(x) in python_scalar_dtypes:
1201
+ dt = python_scalar_dtypes[type(x)]
1202
+ elif hasattr(x, 'dtype'):
1203
+ dt = x.dtype
1204
+ else:
1205
+ dt = np.result_type(x)
1206
+ return dtypes.canonicalize_dtype(dt) if canonicalize else dt
1165
1207
 
1166
1208
 
1167
1209
  def _const(example, val):
1168
- if _is_python_scalar(example):
1169
- dtype = dtypes.canonicalize_dtype(type(example))
1170
- val = dtypes.scalar_type_of(example)(val)
1171
- return val if dtype == _dtype(val, canonicalize=True) else np.array(val, dtype)
1172
- else:
1173
- dtype = dtypes.canonicalize_dtype(example.dtype)
1174
- return np.array(val, dtype)
1210
+ if _is_python_scalar(example):
1211
+ dtype = dtypes.canonicalize_dtype(type(example))
1212
+ val = dtypes.scalar_type_of(example)(val)
1213
+ return val if dtype == _dtype(val, canonicalize=True) else np.array(val, dtype)
1214
+ else:
1215
+ dtype = dtypes.canonicalize_dtype(example.dtype)
1216
+ return np.array(val, dtype)
1175
1217
 
1176
1218
 
1177
1219
  _tr_params = namedtuple(
1178
- "tr_params", ["c", "b", "a", "alpha", "u_r", "v_r", "m", "log_p", "log1_p", "log_h"]
1220
+ "tr_params", ["c", "b", "a", "alpha", "u_r", "v_r", "m", "log_p", "log1_p", "log_h"]
1179
1221
  )
1180
1222
 
1181
1223
 
1182
1224
  def _get_tr_params(n, p):
1183
- # See Table 1. Additionally, we pre-compute log(p), log1p(-p) and the
1184
- # constant terms, that depend only on (n, p, m) in log(f(k)) (bottom of page 5).
1185
- mu = n * p
1186
- spq = jnp.sqrt(mu * (1 - p))
1187
- c = mu + 0.5
1188
- b = 1.15 + 2.53 * spq
1189
- a = -0.0873 + 0.0248 * b + 0.01 * p
1190
- alpha = (2.83 + 5.1 / b) * spq
1191
- u_r = 0.43
1192
- v_r = 0.92 - 4.2 / b
1193
- m = jnp.floor((n + 1) * p).astype(n.dtype)
1194
- log_p = jnp.log(p)
1195
- log1_p = jnp.log1p(-p)
1196
- log_h = ((m + 0.5) * (jnp.log((m + 1.0) / (n - m + 1.0)) + log1_p - log_p) +
1197
- _stirling_approx_tail(m) + _stirling_approx_tail(n - m))
1198
- return _tr_params(c, b, a, alpha, u_r, v_r, m, log_p, log1_p, log_h)
1225
+ # See Table 1. Additionally, we pre-compute log(p), log1p(-p) and the
1226
+ # constant terms, that depend only on (n, p, m) in log(f(k)) (bottom of page 5).
1227
+ mu = n * p
1228
+ spq = jnp.sqrt(mu * (1 - p))
1229
+ c = mu + 0.5
1230
+ b = 1.15 + 2.53 * spq
1231
+ a = -0.0873 + 0.0248 * b + 0.01 * p
1232
+ alpha = (2.83 + 5.1 / b) * spq
1233
+ u_r = 0.43
1234
+ v_r = 0.92 - 4.2 / b
1235
+ m = jnp.floor((n + 1) * p).astype(n.dtype)
1236
+ log_p = jnp.log(p)
1237
+ log1_p = jnp.log1p(-p)
1238
+ log_h = ((m + 0.5) * (jnp.log((m + 1.0) / (n - m + 1.0)) + log1_p - log_p) +
1239
+ _stirling_approx_tail(m) + _stirling_approx_tail(n - m))
1240
+ return _tr_params(c, b, a, alpha, u_r, v_r, m, log_p, log1_p, log_h)
1199
1241
 
1200
1242
 
1201
1243
  def _stirling_approx_tail(k):
1202
- precomputed = jnp.array([0.08106146679532726,
1203
- 0.04134069595540929,
1204
- 0.02767792568499834,
1205
- 0.02079067210376509,
1206
- 0.01664469118982119,
1207
- 0.01387612882307075,
1208
- 0.01189670994589177,
1209
- 0.01041126526197209,
1210
- 0.009255462182712733,
1211
- 0.008330563433362871],
1212
- dtype=environ.dftype())
1213
- kp1 = k + 1
1214
- kp1sq = (k + 1) ** 2
1215
- return jnp.where(k < 10,
1216
- precomputed[k],
1217
- (1.0 / 12 - (1.0 / 360 - (1.0 / 1260) / kp1sq) / kp1sq) / kp1)
1244
+ precomputed = jnp.array([0.08106146679532726,
1245
+ 0.04134069595540929,
1246
+ 0.02767792568499834,
1247
+ 0.02079067210376509,
1248
+ 0.01664469118982119,
1249
+ 0.01387612882307075,
1250
+ 0.01189670994589177,
1251
+ 0.01041126526197209,
1252
+ 0.009255462182712733,
1253
+ 0.008330563433362871],
1254
+ dtype=environ.dftype())
1255
+ kp1 = k + 1
1256
+ kp1sq = (k + 1) ** 2
1257
+ return jnp.where(k < 10,
1258
+ precomputed[k],
1259
+ (1.0 / 12 - (1.0 / 360 - (1.0 / 1260) / kp1sq) / kp1sq) / kp1)
1218
1260
 
1219
1261
 
1220
1262
  def _binomial_btrs(key, p, n):
1221
- """
1222
- Based on the transformed rejection sampling algorithm (BTRS) from the
1223
- following reference:
1224
-
1225
- Hormann, "The Generation of Binomial Random Variates"
1226
- (https://core.ac.uk/download/pdf/11007254.pdf)
1227
- """
1228
-
1229
- def _btrs_body_fn(val):
1230
- _, key, _, _ = val
1231
- key, key_u, key_v = jr.split(key, 3)
1232
- u = jr.uniform(key_u)
1233
- v = jr.uniform(key_v)
1234
- u = u - 0.5
1235
- k = jnp.floor(
1236
- (2 * tr_params.a / (0.5 - jnp.abs(u)) + tr_params.b) * u + tr_params.c
1237
- ).astype(n.dtype)
1238
- return k, key, u, v
1239
-
1240
- def _btrs_cond_fn(val):
1241
- def accept_fn(k, u, v):
1242
- # See acceptance condition in Step 3. (Page 3) of TRS algorithm
1243
- # v <= f(k) * g_grad(u) / alpha
1244
-
1245
- m = tr_params.m
1246
- log_p = tr_params.log_p
1247
- log1_p = tr_params.log1_p
1248
- # See: formula for log(f(k)) at bottom of Page 5.
1249
- log_f = (
1250
- (n + 1.0) * jnp.log((n - m + 1.0) / (n - k + 1.0))
1251
- + (k + 0.5) * (jnp.log((n - k + 1.0) / (k + 1.0)) + log_p - log1_p)
1252
- + (_stirling_approx_tail(k) - _stirling_approx_tail(n - k))
1253
- + tr_params.log_h
1254
- )
1255
- g = (tr_params.a / (0.5 - jnp.abs(u)) ** 2) + tr_params.b
1256
- return jnp.log((v * tr_params.alpha) / g) <= log_f
1257
-
1258
- k, key, u, v = val
1259
- early_accept = (jnp.abs(u) <= tr_params.u_r) & (v <= tr_params.v_r)
1260
- early_reject = (k < 0) | (k > n)
1261
- return lax.cond(
1262
- early_accept | early_reject,
1263
- (),
1264
- lambda _: ~early_accept,
1265
- (k, u, v),
1266
- lambda x: ~accept_fn(*x),
1267
- )
1263
+ """
1264
+ Based on the transformed rejection sampling algorithm (BTRS) from the
1265
+ following reference:
1268
1266
 
1269
- tr_params = _get_tr_params(n, p)
1270
- ret = lax.while_loop(
1271
- _btrs_cond_fn, _btrs_body_fn, (-1, key, 1.0, 1.0)
1272
- ) # use k=-1 initially so that cond_fn returns True
1273
- return ret[0]
1267
+ Hormann, "The Generation of Binonmial Random Variates"
1268
+ (https://core.ac.uk/download/pdf/11007254.pdf)
1269
+ """
1270
+
1271
+ def _btrs_body_fn(val):
1272
+ _, key, _, _ = val
1273
+ key, key_u, key_v = jr.split(key, 3)
1274
+ u = jr.uniform(key_u)
1275
+ v = jr.uniform(key_v)
1276
+ u = u - 0.5
1277
+ k = jnp.floor(
1278
+ (2 * tr_params.a / (0.5 - jnp.abs(u)) + tr_params.b) * u + tr_params.c
1279
+ ).astype(n.dtype)
1280
+ return k, key, u, v
1281
+
1282
+ def _btrs_cond_fn(val):
1283
+ def accept_fn(k, u, v):
1284
+ # See acceptance condition in Step 3. (Page 3) of TRS algorithm
1285
+ # v <= f(k) * g_grad(u) / alpha
1286
+
1287
+ m = tr_params.m
1288
+ log_p = tr_params.log_p
1289
+ log1_p = tr_params.log1_p
1290
+ # See: formula for log(f(k)) at bottom of Page 5.
1291
+ log_f = (
1292
+ (n + 1.0) * jnp.log((n - m + 1.0) / (n - k + 1.0))
1293
+ + (k + 0.5) * (jnp.log((n - k + 1.0) / (k + 1.0)) + log_p - log1_p)
1294
+ + (_stirling_approx_tail(k) - _stirling_approx_tail(n - k))
1295
+ + tr_params.log_h
1296
+ )
1297
+ g = (tr_params.a / (0.5 - jnp.abs(u)) ** 2) + tr_params.b
1298
+ return jnp.log((v * tr_params.alpha) / g) <= log_f
1299
+
1300
+ k, key, u, v = val
1301
+ early_accept = (jnp.abs(u) <= tr_params.u_r) & (v <= tr_params.v_r)
1302
+ early_reject = (k < 0) | (k > n)
1303
+ return lax.cond(
1304
+ early_accept | early_reject,
1305
+ (),
1306
+ lambda _: ~early_accept,
1307
+ (k, u, v),
1308
+ lambda x: ~accept_fn(*x),
1309
+ )
1310
+
1311
+ tr_params = _get_tr_params(n, p)
1312
+ ret = lax.while_loop(
1313
+ _btrs_cond_fn, _btrs_body_fn, (-1, key, 1.0, 1.0)
1314
+ ) # use k=-1 initially so that cond_fn returns True
1315
+ return ret[0]
1274
1316
 
1275
1317
 
1276
1318
  def _binomial_inversion(key, p, n):
1277
- def _binom_inv_body_fn(val):
1278
- i, key, geom_acc = val
1279
- key, key_u = jr.split(key)
1280
- u = jr.uniform(key_u)
1281
- geom = jnp.floor(jnp.log1p(-u) / log1_p) + 1
1282
- geom_acc = geom_acc + geom
1283
- return i + 1, key, geom_acc
1319
+ def _binom_inv_body_fn(val):
1320
+ i, key, geom_acc = val
1321
+ key, key_u = jr.split(key)
1322
+ u = jr.uniform(key_u)
1323
+ geom = jnp.floor(jnp.log1p(-u) / log1_p) + 1
1324
+ geom_acc = geom_acc + geom
1325
+ return i + 1, key, geom_acc
1284
1326
 
1285
- def _binom_inv_cond_fn(val):
1286
- i, _, geom_acc = val
1287
- return geom_acc <= n
1327
+ def _binom_inv_cond_fn(val):
1328
+ i, _, geom_acc = val
1329
+ return geom_acc <= n
1288
1330
 
1289
- log1_p = jnp.log1p(-p)
1290
- ret = lax.while_loop(_binom_inv_cond_fn, _binom_inv_body_fn, (-1, key, 0.0))
1291
- return ret[0]
1331
+ log1_p = jnp.log1p(-p)
1332
+ ret = lax.while_loop(_binom_inv_cond_fn, _binom_inv_body_fn, (-1, key, 0.0))
1333
+ return ret[0]
1292
1334
 
1293
1335
 
1294
1336
  def _binomial_dispatch(key, p, n):
1295
- def dispatch(key, p, n):
1296
- is_le_mid = p <= 0.5
1297
- pq = jnp.where(is_le_mid, p, 1 - p)
1298
- mu = n * pq
1299
- k = lax.cond(
1300
- mu < 10,
1301
- (key, pq, n),
1302
- lambda x: _binomial_inversion(*x),
1303
- (key, pq, n),
1304
- lambda x: _binomial_btrs(*x),
1337
+ def dispatch(key, p, n):
1338
+ is_le_mid = p <= 0.5
1339
+ pq = jnp.where(is_le_mid, p, 1 - p)
1340
+ mu = n * pq
1341
+ k = lax.cond(
1342
+ mu < 10,
1343
+ (key, pq, n),
1344
+ lambda x: _binomial_inversion(*x),
1345
+ (key, pq, n),
1346
+ lambda x: _binomial_btrs(*x),
1347
+ )
1348
+ return jnp.where(is_le_mid, k, n - k)
1349
+
1350
+ # Return 0 for nan `p` or negative `n`, since nan values are not allowed for integer types
1351
+ cond0 = jnp.isfinite(p) & (n > 0) & (p > 0)
1352
+ return lax.cond(
1353
+ cond0 & (p < 1),
1354
+ (key, p, n),
1355
+ lambda x: dispatch(*x),
1356
+ (),
1357
+ lambda _: jnp.where(cond0, n, 0),
1305
1358
  )
1306
- return jnp.where(is_le_mid, k, n - k)
1307
-
1308
- # Return 0 for nan `p` or negative `n`, since nan values are not allowed for integer types
1309
- cond0 = jnp.isfinite(p) & (n > 0) & (p > 0)
1310
- return lax.cond(
1311
- cond0 & (p < 1),
1312
- (key, p, n),
1313
- lambda x: dispatch(*x),
1314
- (),
1315
- lambda _: jnp.where(cond0, n, 0),
1316
- )
1317
1359
 
1318
1360
 
1319
1361
  @partial(jit, static_argnums=(3,))
1320
1362
  def _binomial(key, p, n, shape):
1321
- shape = shape or lax.broadcast_shapes(jnp.shape(p), jnp.shape(n))
1322
- # reshape to map over axis 0
1323
- p = jnp.reshape(jnp.broadcast_to(p, shape), -1)
1324
- n = jnp.reshape(jnp.broadcast_to(n, shape), -1)
1325
- key = jr.split(key, jnp.size(p))
1326
- if jax.default_backend() == "cpu":
1327
- ret = lax.map(lambda x: _binomial_dispatch(*x), (key, p, n))
1328
- else:
1329
- ret = vmap(lambda *x: _binomial_dispatch(*x))(key, p, n)
1330
- return jnp.reshape(ret, shape)
1363
+ shape = shape or lax.broadcast_shapes(jnp.shape(p), jnp.shape(n))
1364
+ # reshape to map over axis 0
1365
+ p = jnp.reshape(jnp.broadcast_to(p, shape), -1)
1366
+ n = jnp.reshape(jnp.broadcast_to(n, shape), -1)
1367
+ key = jr.split(key, jnp.size(p))
1368
+ if jax.default_backend() == "cpu":
1369
+ ret = lax.map(lambda x: _binomial_dispatch(*x), (key, p, n))
1370
+ else:
1371
+ ret = vmap(lambda *x: _binomial_dispatch(*x))(key, p, n)
1372
+ return jnp.reshape(ret, shape)
1331
1373
 
1332
1374
 
1333
1375
  @partial(jit, static_argnums=(2,))
1334
1376
  def _categorical(key, p, shape):
1335
- # this implementation is fast when event shape is small, and slow otherwise
1336
- # Ref: https://stackoverflow.com/a/34190035
1337
- shape = shape or p.shape[:-1]
1338
- s = jnp.cumsum(p, axis=-1)
1339
- r = jr.uniform(key, shape=shape + (1,))
1340
- return jnp.sum(s < r, axis=-1)
1377
+ # this implementation is fast when event shape is small, and slow otherwise
1378
+ # Ref: https://stackoverflow.com/a/34190035
1379
+ shape = shape or p.shape[:-1]
1380
+ s = jnp.cumsum(p, axis=-1)
1381
+ r = jr.uniform(key, shape=shape + (1,))
1382
+ return jnp.sum(s < r, axis=-1)
1341
1383
 
1342
1384
 
1343
1385
  def _scatter_add_one(operand, indices, updates):
1344
- return lax.scatter_add(
1345
- operand,
1346
- indices,
1347
- updates,
1348
- lax.ScatterDimensionNumbers(
1349
- update_window_dims=(),
1350
- inserted_window_dims=(0,),
1351
- scatter_dims_to_operand_dims=(0,),
1352
- ),
1353
- )
1386
+ return lax.scatter_add(
1387
+ operand,
1388
+ indices,
1389
+ updates,
1390
+ lax.ScatterDimensionNumbers(
1391
+ update_window_dims=(),
1392
+ inserted_window_dims=(0,),
1393
+ scatter_dims_to_operand_dims=(0,),
1394
+ ),
1395
+ )
1354
1396
 
1355
1397
 
1356
1398
  def _reshape(x, shape):
1357
- if isinstance(x, (int, float, np.ndarray, np.generic)):
1358
- return np.reshape(x, shape)
1359
- else:
1360
- return jnp.reshape(x, shape)
1399
+ if isinstance(x, (int, float, np.ndarray, np.generic)):
1400
+ return np.reshape(x, shape)
1401
+ else:
1402
+ return jnp.reshape(x, shape)
1361
1403
 
1362
1404
 
1363
1405
  def _promote_shapes(*args, shape=()):
1364
- # adapted from lax.lax_numpy
1365
- if len(args) < 2 and not shape:
1366
- return args
1367
- else:
1368
- shapes = [jnp.shape(arg) for arg in args]
1369
- num_dims = len(lax.broadcast_shapes(shape, *shapes))
1370
- return [
1371
- _reshape(arg, (1,) * (num_dims - len(s)) + s) if len(s) < num_dims else arg
1372
- for arg, s in zip(args, shapes)
1373
- ]
1406
+ # adapted from lax.lax_numpy
1407
+ if len(args) < 2 and not shape:
1408
+ return args
1409
+ else:
1410
+ shapes = [jnp.shape(arg) for arg in args]
1411
+ num_dims = len(lax.broadcast_shapes(shape, *shapes))
1412
+ return [
1413
+ _reshape(arg, (1,) * (num_dims - len(s)) + s) if len(s) < num_dims else arg
1414
+ for arg, s in zip(args, shapes)
1415
+ ]
1374
1416
 
1375
1417
 
1376
1418
  @partial(jit, static_argnums=(3, 4))
1377
1419
  def _multinomial(key, p, n, n_max, shape=()):
1378
- if jnp.shape(n) != jnp.shape(p)[:-1]:
1379
- broadcast_shape = lax.broadcast_shapes(jnp.shape(n), jnp.shape(p)[:-1])
1380
- n = jnp.broadcast_to(n, broadcast_shape)
1381
- p = jnp.broadcast_to(p, broadcast_shape + jnp.shape(p)[-1:])
1382
- shape = shape or p.shape[:-1]
1383
- if n_max == 0:
1384
- return jnp.zeros(shape + p.shape[-1:], dtype=jnp.result_type(int))
1385
- # get indices from categorical distribution then gather the result
1386
- indices = _categorical(key, p, (n_max,) + shape)
1387
- # mask out values when counts is heterogeneous
1388
- if jnp.ndim(n) > 0:
1389
- mask = _promote_shapes(jnp.arange(n_max) < jnp.expand_dims(n, -1), shape=shape + (n_max,))[0]
1390
- mask = jnp.moveaxis(mask, -1, 0).astype(indices.dtype)
1391
- excess = jnp.concatenate([jnp.expand_dims(n_max - n, -1),
1392
- jnp.zeros(jnp.shape(n) + (p.shape[-1] - 1,))],
1393
- -1)
1394
- else:
1395
- mask = 1
1396
- excess = 0
1397
- # NB: we transpose to move batch shape to the front
1398
- indices_2D = (jnp.reshape(indices * mask, (n_max, -1))).T
1399
- samples_2D = vmap(_scatter_add_one)(jnp.zeros((indices_2D.shape[0], p.shape[-1]), dtype=indices.dtype),
1400
- jnp.expand_dims(indices_2D, axis=-1),
1401
- jnp.ones(indices_2D.shape, dtype=indices.dtype))
1402
- return jnp.reshape(samples_2D, shape + p.shape[-1:]) - excess
1420
+ if jnp.shape(n) != jnp.shape(p)[:-1]:
1421
+ broadcast_shape = lax.broadcast_shapes(jnp.shape(n), jnp.shape(p)[:-1])
1422
+ n = jnp.broadcast_to(n, broadcast_shape)
1423
+ p = jnp.broadcast_to(p, broadcast_shape + jnp.shape(p)[-1:])
1424
+ shape = shape or p.shape[:-1]
1425
+ if n_max == 0:
1426
+ return jnp.zeros(shape + p.shape[-1:], dtype=jnp.result_type(int))
1427
+ # get indices from categorical distribution then gather the result
1428
+ indices = _categorical(key, p, (n_max,) + shape)
1429
+ # mask out values when counts is heterogeneous
1430
+ if jnp.ndim(n) > 0:
1431
+ mask = _promote_shapes(jnp.arange(n_max) < jnp.expand_dims(n, -1), shape=shape + (n_max,))[0]
1432
+ mask = jnp.moveaxis(mask, -1, 0).astype(indices.dtype)
1433
+ excess = jnp.concatenate([jnp.expand_dims(n_max - n, -1),
1434
+ jnp.zeros(jnp.shape(n) + (p.shape[-1] - 1,))],
1435
+ -1)
1436
+ else:
1437
+ mask = 1
1438
+ excess = 0
1439
+ # NB: we transpose to move batch shape to the front
1440
+ indices_2D = (jnp.reshape(indices * mask, (n_max, -1))).T
1441
+ samples_2D = vmap(_scatter_add_one)(jnp.zeros((indices_2D.shape[0], p.shape[-1]), dtype=indices.dtype),
1442
+ jnp.expand_dims(indices_2D, axis=-1),
1443
+ jnp.ones(indices_2D.shape, dtype=indices.dtype))
1444
+ return jnp.reshape(samples_2D, shape + p.shape[-1:]) - excess
1403
1445
 
1404
1446
 
1405
1447
  @partial(jit, static_argnums=(2, 3), static_argnames=['shape', 'dtype'])
1406
1448
  def _von_mises_centered(key, concentration, shape, dtype=None):
1407
- """Compute centered von Mises samples using rejection sampling from [1]_ with wrapped Cauchy proposal.
1408
-
1409
- Returns
1410
- -------
1411
- out: array_like
1412
- centered samples from von Mises
1413
-
1414
- References
1415
- ----------
1416
- .. [1] Luc Devroye "Non-Uniform Random Variate Generation", Springer-Verlag, 1986;
1417
- Chapter 9, p. 473-476. http://www.nrbook.com/devroye/Devroye_files/chapter_nine.pdf
1418
-
1419
- """
1420
- shape = shape or jnp.shape(concentration)
1421
- dtype = dtype or environ.dftype()
1422
- concentration = lax.convert_element_type(concentration, dtype)
1423
- concentration = jnp.broadcast_to(concentration, shape)
1424
-
1425
- if dtype == jnp.float16:
1426
- s_cutoff = 1.8e-1
1427
- elif dtype == jnp.float32:
1428
- s_cutoff = 2e-2
1429
- elif dtype == jnp.float64:
1430
- s_cutoff = 1.2e-4
1431
- else:
1432
- raise ValueError(f"Unsupported dtype: {dtype}")
1433
-
1434
- r = 1.0 + jnp.sqrt(1.0 + 4.0 * concentration ** 2)
1435
- rho = (r - jnp.sqrt(2.0 * r)) / (2.0 * concentration)
1436
- s_exact = (1.0 + rho ** 2) / (2.0 * rho)
1437
-
1438
- s_approximate = 1.0 / concentration
1439
-
1440
- s = jnp.where(concentration > s_cutoff, s_exact, s_approximate)
1441
-
1442
- def cond_fn(*args):
1443
- """check if all are done or reached max number of iterations"""
1444
- i, _, done, _, _ = args[0]
1445
- return jnp.bitwise_and(i < 100, jnp.logical_not(jnp.all(done)))
1446
-
1447
- def body_fn(*args):
1448
- i, key, done, _, w = args[0]
1449
- uni_ukey, uni_vkey, key = jr.split(key, 3)
1450
- u = jr.uniform(
1451
- key=uni_ukey,
1452
- shape=shape,
1453
- dtype=concentration.dtype,
1454
- minval=-1.0,
1455
- maxval=1.0,
1456
- )
1457
- z = jnp.cos(jnp.pi * u)
1458
- w = jnp.where(done, w, (1.0 + s * z) / (s + z)) # Update where not done
1459
- y = concentration * (s - w)
1460
- v = jr.uniform(key=uni_vkey, shape=shape, dtype=concentration.dtype)
1461
- accept = (y * (2.0 - y) >= v) | (jnp.log(y / v) + 1.0 >= y)
1462
- return i + 1, key, accept | done, u, w
1449
+ """Compute centered von Mises samples using rejection sampling from [1]_ with wrapped Cauchy proposal.
1463
1450
 
1464
- init_done = jnp.zeros(shape, dtype=bool)
1465
- init_u = jnp.zeros(shape)
1466
- init_w = jnp.zeros(shape)
1451
+ Returns
1452
+ -------
1453
+ out: array_like
1454
+ centered samples from von Mises
1467
1455
 
1468
- _, _, done, u, w = lax.while_loop(
1469
- cond_fun=cond_fn,
1470
- body_fun=body_fn,
1471
- init_val=(jnp.array(0), key, init_done, init_u, init_w),
1472
- )
1456
+ References
1457
+ ----------
1458
+ .. [1] Luc Devroye "Non-Uniform Random Variate Generation", Springer-Verlag, 1986;
1459
+ Chapter 9, p. 473-476. http://www.nrbook.com/devroye/Devroye_files/chapter_nine.pdf
1473
1460
 
1474
- return jnp.sign(u) * jnp.arccos(w)
1461
+ """
1462
+ shape = shape or jnp.shape(concentration)
1463
+ dtype = dtype or environ.dftype()
1464
+ concentration = lax.convert_element_type(concentration, dtype)
1465
+ concentration = jnp.broadcast_to(concentration, shape)
1466
+
1467
+ if dtype == jnp.float16:
1468
+ s_cutoff = 1.8e-1
1469
+ elif dtype == jnp.float32:
1470
+ s_cutoff = 2e-2
1471
+ elif dtype == jnp.float64:
1472
+ s_cutoff = 1.2e-4
1473
+ else:
1474
+ raise ValueError(f"Unsupported dtype: {dtype}")
1475
+
1476
+ r = 1.0 + jnp.sqrt(1.0 + 4.0 * concentration ** 2)
1477
+ rho = (r - jnp.sqrt(2.0 * r)) / (2.0 * concentration)
1478
+ s_exact = (1.0 + rho ** 2) / (2.0 * rho)
1479
+
1480
+ s_approximate = 1.0 / concentration
1481
+
1482
+ s = jnp.where(concentration > s_cutoff, s_exact, s_approximate)
1483
+
1484
+ def cond_fn(*args):
1485
+ """check if all are done or reached max number of iterations"""
1486
+ i, _, done, _, _ = args[0]
1487
+ return jnp.bitwise_and(i < 100, jnp.logical_not(jnp.all(done)))
1488
+
1489
+ def body_fn(*args):
1490
+ i, key, done, _, w = args[0]
1491
+ uni_ukey, uni_vkey, key = jr.split(key, 3)
1492
+ u = jr.uniform(
1493
+ key=uni_ukey,
1494
+ shape=shape,
1495
+ dtype=concentration.dtype,
1496
+ minval=-1.0,
1497
+ maxval=1.0,
1498
+ )
1499
+ z = jnp.cos(jnp.pi * u)
1500
+ w = jnp.where(done, w, (1.0 + s * z) / (s + z)) # Update where not done
1501
+ y = concentration * (s - w)
1502
+ v = jr.uniform(key=uni_vkey, shape=shape, dtype=concentration.dtype)
1503
+ accept = (y * (2.0 - y) >= v) | (jnp.log(y / v) + 1.0 >= y)
1504
+ return i + 1, key, accept | done, u, w
1505
+
1506
+ init_done = jnp.zeros(shape, dtype=bool)
1507
+ init_u = jnp.zeros(shape)
1508
+ init_w = jnp.zeros(shape)
1509
+
1510
+ _, _, done, u, w = lax.while_loop(
1511
+ cond_fun=cond_fn,
1512
+ body_fun=body_fn,
1513
+ init_val=(jnp.array(0), key, init_done, init_u, init_w),
1514
+ )
1515
+
1516
+ return jnp.sign(u) * jnp.arccos(w)
1475
1517
 
1476
1518
 
1477
1519
  def _loc_scale(loc, scale, value):
1478
- if loc is None:
1479
- if scale is None:
1480
- return value
1481
- else:
1482
- return value * scale
1483
- else:
1484
- if scale is None:
1485
- return value + loc
1520
+ if loc is None:
1521
+ if scale is None:
1522
+ return value
1523
+ else:
1524
+ return value * scale
1486
1525
  else:
1487
- return value * scale + loc
1526
+ if scale is None:
1527
+ return value + loc
1528
+ else:
1529
+ return value * scale + loc
1488
1530
 
1489
1531
 
1490
1532
  def _check_py_seq(seq):
1491
- return jnp.asarray(seq) if isinstance(seq, (tuple, list)) else seq
1533
+ return jnp.asarray(seq) if isinstance(seq, (tuple, list)) else seq