brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0.post20241122__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (184):
  1. benchmark/COBA_2005.py +125 -0
  2. benchmark/CUBA_2005.py +149 -0
  3. brainstate/__init__.py +31 -11
  4. brainstate/_state.py +760 -316
  5. brainstate/_state_test.py +41 -12
  6. brainstate/_utils.py +31 -4
  7. brainstate/augment/__init__.py +40 -0
  8. brainstate/augment/_autograd.py +611 -0
  9. brainstate/augment/_autograd_test.py +1193 -0
  10. brainstate/augment/_eval_shape.py +102 -0
  11. brainstate/augment/_eval_shape_test.py +40 -0
  12. brainstate/augment/_mapping.py +525 -0
  13. brainstate/augment/_mapping_test.py +210 -0
  14. brainstate/augment/_random.py +99 -0
  15. brainstate/{transform → compile}/__init__.py +25 -13
  16. brainstate/compile/_ad_checkpoint.py +204 -0
  17. brainstate/compile/_ad_checkpoint_test.py +51 -0
  18. brainstate/compile/_conditions.py +259 -0
  19. brainstate/compile/_conditions_test.py +221 -0
  20. brainstate/compile/_error_if.py +94 -0
  21. brainstate/compile/_error_if_test.py +54 -0
  22. brainstate/compile/_jit.py +314 -0
  23. brainstate/compile/_jit_test.py +143 -0
  24. brainstate/compile/_loop_collect_return.py +516 -0
  25. brainstate/compile/_loop_collect_return_test.py +59 -0
  26. brainstate/compile/_loop_no_collection.py +185 -0
  27. brainstate/compile/_loop_no_collection_test.py +51 -0
  28. brainstate/compile/_make_jaxpr.py +756 -0
  29. brainstate/compile/_make_jaxpr_test.py +134 -0
  30. brainstate/compile/_progress_bar.py +111 -0
  31. brainstate/compile/_unvmap.py +159 -0
  32. brainstate/compile/_util.py +147 -0
  33. brainstate/environ.py +408 -381
  34. brainstate/environ_test.py +34 -32
  35. brainstate/event/__init__.py +27 -0
  36. brainstate/event/_csr.py +316 -0
  37. brainstate/event/_csr_benchmark.py +14 -0
  38. brainstate/event/_csr_test.py +118 -0
  39. brainstate/event/_fixed_probability.py +708 -0
  40. brainstate/event/_fixed_probability_benchmark.py +128 -0
  41. brainstate/event/_fixed_probability_test.py +131 -0
  42. brainstate/event/_linear.py +359 -0
  43. brainstate/event/_linear_benckmark.py +82 -0
  44. brainstate/event/_linear_test.py +117 -0
  45. brainstate/{nn/event → event}/_misc.py +7 -7
  46. brainstate/event/_xla_custom_op.py +312 -0
  47. brainstate/event/_xla_custom_op_test.py +55 -0
  48. brainstate/functional/_activations.py +521 -511
  49. brainstate/functional/_activations_test.py +300 -300
  50. brainstate/functional/_normalization.py +43 -43
  51. brainstate/functional/_others.py +15 -15
  52. brainstate/functional/_spikes.py +49 -49
  53. brainstate/graph/__init__.py +33 -0
  54. brainstate/graph/_graph_context.py +443 -0
  55. brainstate/graph/_graph_context_test.py +65 -0
  56. brainstate/graph/_graph_convert.py +246 -0
  57. brainstate/graph/_graph_node.py +300 -0
  58. brainstate/graph/_graph_node_test.py +75 -0
  59. brainstate/graph/_graph_operation.py +1746 -0
  60. brainstate/graph/_graph_operation_test.py +724 -0
  61. brainstate/init/_base.py +28 -10
  62. brainstate/init/_generic.py +175 -172
  63. brainstate/init/_random_inits.py +470 -415
  64. brainstate/init/_random_inits_test.py +150 -0
  65. brainstate/init/_regular_inits.py +66 -69
  66. brainstate/init/_regular_inits_test.py +51 -0
  67. brainstate/mixin.py +236 -244
  68. brainstate/mixin_test.py +44 -46
  69. brainstate/nn/__init__.py +26 -51
  70. brainstate/nn/_collective_ops.py +199 -0
  71. brainstate/nn/_dyn_impl/__init__.py +46 -0
  72. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  73. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  74. brainstate/nn/_dyn_impl/_dynamics_synapse.py +315 -0
  75. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  76. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  77. brainstate/nn/{event/__init__.py → _dyn_impl/_projection_alignpost.py} +8 -8
  78. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  79. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  80. brainstate/nn/_dyn_impl/_readout.py +128 -0
  81. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  82. brainstate/nn/_dynamics/__init__.py +37 -0
  83. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  84. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  85. brainstate/nn/_dynamics/_projection_base.py +346 -0
  86. brainstate/nn/_dynamics/_state_delay.py +453 -0
  87. brainstate/nn/_dynamics/_synouts.py +161 -0
  88. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  89. brainstate/nn/_elementwise/__init__.py +22 -0
  90. brainstate/nn/_elementwise/_dropout.py +418 -0
  91. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  92. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  93. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  94. brainstate/nn/_exp_euler.py +97 -0
  95. brainstate/nn/_exp_euler_test.py +36 -0
  96. brainstate/nn/_interaction/__init__.py +41 -0
  97. brainstate/nn/_interaction/_conv.py +499 -0
  98. brainstate/nn/_interaction/_conv_test.py +239 -0
  99. brainstate/nn/_interaction/_embedding.py +59 -0
  100. brainstate/nn/_interaction/_linear.py +582 -0
  101. brainstate/nn/_interaction/_linear_test.py +42 -0
  102. brainstate/nn/_interaction/_normalizations.py +388 -0
  103. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  104. brainstate/nn/_interaction/_poolings.py +1179 -0
  105. brainstate/nn/_interaction/_poolings_test.py +219 -0
  106. brainstate/nn/_module.py +328 -0
  107. brainstate/nn/_module_test.py +211 -0
  108. brainstate/nn/metrics.py +309 -309
  109. brainstate/optim/__init__.py +14 -2
  110. brainstate/optim/_base.py +66 -0
  111. brainstate/optim/_lr_scheduler.py +363 -400
  112. brainstate/optim/_lr_scheduler_test.py +25 -24
  113. brainstate/optim/_optax_optimizer.py +121 -176
  114. brainstate/optim/_optax_optimizer_test.py +41 -1
  115. brainstate/optim/_sgd_optimizer.py +950 -1025
  116. brainstate/random/_rand_funs.py +3269 -3268
  117. brainstate/random/_rand_funs_test.py +568 -0
  118. brainstate/random/_rand_seed.py +149 -117
  119. brainstate/random/_rand_seed_test.py +50 -0
  120. brainstate/random/_rand_state.py +1356 -1321
  121. brainstate/random/_random_for_unit.py +13 -13
  122. brainstate/surrogate.py +1262 -1243
  123. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  124. brainstate/typing.py +157 -130
  125. brainstate/util/__init__.py +52 -0
  126. brainstate/util/_caller.py +100 -0
  127. brainstate/util/_dict.py +734 -0
  128. brainstate/util/_dict_test.py +160 -0
  129. brainstate/{nn/_projection/__init__.py → util/_error.py} +9 -13
  130. brainstate/util/_filter.py +178 -0
  131. brainstate/util/_others.py +497 -0
  132. brainstate/util/_pretty_repr.py +208 -0
  133. brainstate/util/_scaling.py +260 -0
  134. brainstate/util/_struct.py +524 -0
  135. brainstate/util/_tracers.py +75 -0
  136. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  137. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +11 -11
  138. brainstate-0.1.0.post20241122.dist-info/RECORD +144 -0
  139. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
  140. brainstate/_module.py +0 -1637
  141. brainstate/_module_test.py +0 -207
  142. brainstate/nn/_base.py +0 -251
  143. brainstate/nn/_connections.py +0 -686
  144. brainstate/nn/_dynamics.py +0 -426
  145. brainstate/nn/_elementwise.py +0 -1438
  146. brainstate/nn/_embedding.py +0 -66
  147. brainstate/nn/_misc.py +0 -133
  148. brainstate/nn/_normalizations.py +0 -389
  149. brainstate/nn/_others.py +0 -101
  150. brainstate/nn/_poolings.py +0 -1229
  151. brainstate/nn/_poolings_test.py +0 -231
  152. brainstate/nn/_projection/_align_post.py +0 -546
  153. brainstate/nn/_projection/_align_pre.py +0 -599
  154. brainstate/nn/_projection/_delta.py +0 -241
  155. brainstate/nn/_projection/_vanilla.py +0 -101
  156. brainstate/nn/_rate_rnns.py +0 -410
  157. brainstate/nn/_readout.py +0 -136
  158. brainstate/nn/_synouts.py +0 -166
  159. brainstate/nn/event/csr.py +0 -312
  160. brainstate/nn/event/csr_test.py +0 -118
  161. brainstate/nn/event/fixed_probability.py +0 -276
  162. brainstate/nn/event/fixed_probability_test.py +0 -127
  163. brainstate/nn/event/linear.py +0 -220
  164. brainstate/nn/event/linear_test.py +0 -111
  165. brainstate/random/random_test.py +0 -593
  166. brainstate/transform/_autograd.py +0 -585
  167. brainstate/transform/_autograd_test.py +0 -1181
  168. brainstate/transform/_conditions.py +0 -334
  169. brainstate/transform/_conditions_test.py +0 -220
  170. brainstate/transform/_error_if.py +0 -94
  171. brainstate/transform/_error_if_test.py +0 -55
  172. brainstate/transform/_jit.py +0 -265
  173. brainstate/transform/_jit_test.py +0 -118
  174. brainstate/transform/_loop_collect_return.py +0 -502
  175. brainstate/transform/_loop_no_collection.py +0 -170
  176. brainstate/transform/_make_jaxpr.py +0 -739
  177. brainstate/transform/_make_jaxpr_test.py +0 -131
  178. brainstate/transform/_mapping.py +0 -109
  179. brainstate/transform/_progress_bar.py +0 -111
  180. brainstate/transform/_unvmap.py +0 -143
  181. brainstate/util.py +0 -746
  182. brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
  183. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
  184. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
 
16
16
  # -*- coding: utf-8 -*-
17
+ from __future__ import annotations
17
18
 
18
19
  from collections import namedtuple
19
20
  from functools import partial
@@ -30,7 +31,7 @@ from jax import lax, core, dtypes
30
31
 
31
32
  from brainstate import environ
32
33
  from brainstate._state import State
33
- from brainstate.transform._error_if import jit_error_if
34
+ from brainstate.compile._error_if import jit_error_if
34
35
  from brainstate.typing import DTypeLike, Size, SeedOrKey
35
36
  from ._random_for_unit import uniform_for_unit, permutation_for_unit
36
37
 
@@ -38,1064 +39,1098 @@ __all__ = ['RandomState', 'DEFAULT', ]
38
39
 
39
40
 
40
41
  class RandomState(State):
41
- """RandomState that track the random generator state. """
42
- __slots__ = ()
43
-
44
- def __init__(self, seed_or_key: Optional[SeedOrKey] = None):
45
- """RandomState constructor.
46
-
47
- Parameters
48
- ----------
49
- seed_or_key: int, Array, optional
50
- It can be an integer for initial seed of the random number generator,
51
- or it can be a JAX's PRNKey, which is an array with two elements and `uint32` dtype.
52
- """
53
- with jax.ensure_compile_time_eval():
54
- if seed_or_key is None:
55
- seed_or_key = np.random.randint(0, 100000, 2, dtype=np.uint32)
56
- if isinstance(seed_or_key, int):
57
- key = jr.PRNGKey(seed_or_key)
58
- else:
59
- if len(seed_or_key) != 2 and seed_or_key.dtype != np.uint32:
60
- raise ValueError('key must be an array with dtype uint32. '
61
- f'But we got {seed_or_key}')
62
- key = seed_or_key
63
- super().__init__(key)
64
-
65
- def __repr__(self) -> str:
66
- print_code = repr(self.value)
67
- i = print_code.index('(')
68
- return f'{self.__class__.__name__}(key={print_code[i:]})'
69
-
70
- def _check_if_deleted(self):
71
- if isinstance(self._value, jax.Array) and not isinstance(self._value, jax.core.Tracer) and self._value.is_deleted():
72
- self.seed()
73
-
74
- # ------------------- #
75
- # seed and random key #
76
- # ------------------- #
77
-
78
- def clone(self):
79
- return type(self)(self.split_key())
80
-
81
- def set_key(self, key: SeedOrKey):
82
- self.value = key
83
-
84
-
85
- def seed(self, seed_or_key: Optional[SeedOrKey] = None):
86
- """Sets a new random seed.
87
-
88
- Parameters
89
- ----------
90
- seed_or_key: int, ArrayLike, optional
91
- It can be an integer for initial seed of the random number generator,
92
- or it can be a JAX's PRNKey, which is an array with two elements and `uint32` dtype.
93
- """
94
- with jax.ensure_compile_time_eval():
95
- if seed_or_key is None:
96
- seed_or_key = np.random.randint(0, 100000, 2, dtype=np.uint32)
97
- if np.size(seed_or_key) == 1:
98
- key = jr.PRNGKey(seed_or_key)
99
- else:
100
- if len(seed_or_key) != 2 and seed_or_key.dtype != np.uint32:
101
- raise ValueError('key must be an array with dtype uint32. '
102
- f'But we got {seed_or_key}')
103
- key = seed_or_key
104
- self.value = key
105
-
106
- def split_key(self):
107
- """Create a new seed from the current seed.
108
- """
109
- if not isinstance(self.value, jax.Array):
110
- self.value = jnp.asarray(self.value, dtype=jnp.uint32)
111
- keys = jr.split(self.value, num=2)
112
- self.value = keys[0]
113
- return keys[1]
114
-
115
- def split_keys(self, n: int):
116
- """Create multiple seeds from the current seed. This is used
117
- internally by `pmap` and `vmap` to ensure that random numbers
118
- are different in parallel threads.
119
-
120
- Parameters
121
- ----------
122
- n : int
123
- The number of seeds to generate.
124
- """
125
- keys = jr.split(self.value, n + 1)
126
- self.value = keys[0]
127
- return keys[1:]
128
-
129
- # ---------------- #
130
- # random functions #
131
- # ---------------- #
42
+ """RandomState that track the random generator state. """
43
+
44
+ # __slots__ = ('_backup', '_value')
45
+
46
+ def __init__(self, seed_or_key: Optional[SeedOrKey] = None):
47
+ """RandomState constructor.
48
+
49
+ Parameters
50
+ ----------
51
+ seed_or_key: int, Array, optional
52
+ It can be an integer for initial seed of the random number generator,
53
+ or it can be a JAX's PRNKey, which is an array with two elements and `uint32` dtype.
54
+ """
55
+ with jax.ensure_compile_time_eval():
56
+ if seed_or_key is None:
57
+ seed_or_key = np.random.randint(0, 100000, 2, dtype=np.uint32)
58
+ if isinstance(seed_or_key, int):
59
+ key = jr.PRNGKey(seed_or_key)
60
+ else:
61
+ if len(seed_or_key) != 2 and seed_or_key.dtype != np.uint32:
62
+ raise ValueError('key must be an array with dtype uint32. '
63
+ f'But we got {seed_or_key}')
64
+ key = seed_or_key
65
+ super().__init__(key)
66
+
67
+ self._backup = None
68
+
69
+ def __repr__(self):
70
+ return f'{self.__class__.__name__}({self.value})'
71
+
72
+ def check_if_deleted(self):
73
+ if (
74
+ isinstance(self._value, jax.Array) and
75
+ not isinstance(self._value, jax.core.Tracer) and
76
+ self._value.is_deleted()
77
+ ):
78
+ self.seed()
79
+
80
+ # ------------------- #
81
+ # seed and random key #
82
+ # ------------------- #
83
+
84
+ def backup_key(self):
85
+ if self._backup is not None:
86
+ raise ValueError('The random key has been backed up, and has not been restored.')
87
+ self._backup = self.value
88
+
89
+ def restore_key(self):
90
+ if self._backup is None:
91
+ raise ValueError('The random key has not been backed up.')
92
+ self.value = self._backup
93
+ self._backup = None
94
+
95
+ def clone(self):
96
+ return type(self)(self.split_key())
97
+
98
+ def set_key(self, key: SeedOrKey):
99
+ self.value = key
100
+
101
+ def seed(self, seed_or_key: Optional[SeedOrKey] = None):
102
+ """Sets a new random seed.
103
+
104
+ Parameters
105
+ ----------
106
+ seed_or_key: int, ArrayLike, optional
107
+ It can be an integer for initial seed of the random number generator,
108
+ or it can be a JAX's PRNKey, which is an array with two elements and `uint32` dtype.
109
+ """
110
+ with jax.ensure_compile_time_eval():
111
+ if seed_or_key is None:
112
+ seed_or_key = np.random.randint(0, 100000, 2, dtype=np.uint32)
113
+ if np.size(seed_or_key) == 1:
114
+ key = jr.PRNGKey(seed_or_key)
115
+ else:
116
+ if len(seed_or_key) != 2 and seed_or_key.dtype != np.uint32:
117
+ raise ValueError('key must be an array with dtype uint32. '
118
+ f'But we got {seed_or_key}')
119
+ key = seed_or_key
120
+ self.value = key
121
+
122
+ def split_key(self, n: Optional[int] = None, backup: bool = False) -> SeedOrKey:
123
+ """
124
+ Create a new seed from the current seed.
125
+
126
+ Parameters
127
+ ----------
128
+ n: int, optional
129
+ The number of seeds to generate.
130
+ backup : bool, optional
131
+ Whether to backup the current key.
132
+
133
+ Returns
134
+ -------
135
+ key : SeedOrKey
136
+ The new seed or a tuple of JAX random keys.
137
+ """
138
+ if n is not None:
139
+ assert isinstance(n, int) and n >= 1, f'n should be an integer greater than 1, but we got {n}'
140
+
141
+ if not isinstance(self.value, jax.Array):
142
+ self.value = jnp.asarray(self.value, dtype=jnp.uint32)
143
+ keys = jr.split(self.value, num=2 if n is None else n + 1)
144
+ self.value = keys[0]
145
+ if backup:
146
+ self.backup_key()
147
+ if n is None:
148
+ return keys[1]
149
+ else:
150
+ return keys[1:]
151
+
152
+ def self_assign_multi_keys(self, n: int, backup: bool = True):
153
+ """
154
+ Self-assign multiple keys to the current random state.
155
+ """
156
+ if backup:
157
+ keys = jr.split(self.value, n + 1)
158
+ self.value = keys[0]
159
+ self.backup_key()
160
+ self.value = keys[1:]
161
+ else:
162
+ self.value = jr.split(self.value, n)
163
+
164
+ # ---------------- #
165
+ # random functions #
166
+ # ---------------- #
167
+
168
+ def rand(self, *dn, key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
169
+ key = self.split_key() if key is None else _formalize_key(key)
170
+ dtype = dtype or environ.dftype()
171
+ r = uniform_for_unit(key, shape=dn, minval=0., maxval=1., dtype=dtype)
172
+ return r
173
+
174
+ def randint(
175
+ self,
176
+ low,
177
+ high=None,
178
+ size: Optional[Size] = None,
179
+ dtype: DTypeLike = None,
180
+ key: Optional[SeedOrKey] = None
181
+ ):
182
+ if high is None:
183
+ high = low
184
+ low = 0
185
+ high = _check_py_seq(high)
186
+ low = _check_py_seq(low)
187
+ if size is None:
188
+ size = lax.broadcast_shapes(jnp.shape(low),
189
+ jnp.shape(high))
190
+ key = self.split_key() if key is None else _formalize_key(key)
191
+ dtype = dtype or environ.ditype()
192
+ r = jr.randint(key,
193
+ shape=_size2shape(size),
194
+ minval=low, maxval=high, dtype=dtype)
195
+ return r
196
+
197
+ def random_integers(
198
+ self,
199
+ low,
200
+ high=None,
201
+ size: Optional[Size] = None,
202
+ key: Optional[SeedOrKey] = None,
203
+ dtype: DTypeLike = None,
204
+ ):
205
+ low = _check_py_seq(low)
206
+ high = _check_py_seq(high)
207
+ if high is None:
208
+ high = low
209
+ low = 1
210
+ high += 1
211
+ if size is None:
212
+ size = lax.broadcast_shapes(jnp.shape(low), jnp.shape(high))
213
+ key = self.split_key() if key is None else _formalize_key(key)
214
+ dtype = dtype or environ.ditype()
215
+ r = jr.randint(key,
216
+ shape=_size2shape(size),
217
+ minval=low,
218
+ maxval=high,
219
+ dtype=dtype)
220
+ return r
221
+
222
+ def randn(self, *dn, key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
223
+ key = self.split_key() if key is None else _formalize_key(key)
224
+ dtype = dtype or environ.dftype()
225
+ r = jr.normal(key, shape=dn, dtype=dtype)
226
+ return r
227
+
228
+ def random(self,
229
+ size: Optional[Size] = None,
230
+ key: Optional[SeedOrKey] = None,
231
+ dtype: DTypeLike = None):
232
+ dtype = dtype or environ.dftype()
233
+ key = self.split_key() if key is None else _formalize_key(key)
234
+ r = uniform_for_unit(key, shape=_size2shape(size), minval=0., maxval=1., dtype=dtype)
235
+ return r
132
236
 
133
- def rand(self, *dn, key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
134
- key = self.split_key() if key is None else _formalize_key(key)
135
- dtype = dtype or environ.dftype()
136
- r = uniform_for_unit(key, shape=dn, minval=0., maxval=1., dtype=dtype)
137
- return r
138
-
139
- def randint(
140
- self,
141
- low,
142
- high=None,
143
- size: Optional[Size] = None,
144
- dtype: DTypeLike = None,
145
- key: Optional[SeedOrKey] = None
146
- ):
147
- if high is None:
148
- high = low
149
- low = 0
150
- high = _check_py_seq(high)
151
- low = _check_py_seq(low)
152
- if size is None:
153
- size = lax.broadcast_shapes(jnp.shape(low),
154
- jnp.shape(high))
155
- key = self.split_key() if key is None else _formalize_key(key)
156
- dtype = dtype or environ.ditype()
157
- r = jr.randint(key,
158
- shape=_size2shape(size),
159
- minval=low, maxval=high, dtype=dtype)
160
- return r
161
-
162
- def random_integers(
163
- self,
164
- low,
165
- high=None,
166
- size: Optional[Size] = None,
167
- key: Optional[SeedOrKey] = None,
168
- dtype: DTypeLike = None,
169
- ):
170
- low = _check_py_seq(low)
171
- high = _check_py_seq(high)
172
- if high is None:
173
- high = low
174
- low = 1
175
- high += 1
176
- if size is None:
177
- size = lax.broadcast_shapes(jnp.shape(low), jnp.shape(high))
178
- key = self.split_key() if key is None else _formalize_key(key)
179
- dtype = dtype or environ.ditype()
180
- r = jr.randint(key,
181
- shape=_size2shape(size),
182
- minval=low,
183
- maxval=high,
184
- dtype=dtype)
185
- return r
186
-
187
- def randn(self, *dn, key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
188
- key = self.split_key() if key is None else _formalize_key(key)
189
- dtype = dtype or environ.dftype()
190
- r = jr.normal(key, shape=dn, dtype=dtype)
191
- return r
237
+ def random_sample(self,
238
+ size: Optional[Size] = None,
239
+ key: Optional[SeedOrKey] = None,
240
+ dtype: DTypeLike = None):
241
+ r = self.random(size=size, key=key, dtype=dtype)
242
+ return r
192
243
 
193
- def random(self,
244
+ def ranf(self,
194
245
  size: Optional[Size] = None,
195
246
  key: Optional[SeedOrKey] = None,
196
247
  dtype: DTypeLike = None):
197
- dtype = dtype or environ.dftype()
198
- key = self.split_key() if key is None else _formalize_key(key)
199
- r = uniform_for_unit(key, shape=_size2shape(size), minval=0., maxval=1., dtype=dtype)
200
- return r
201
-
202
- def random_sample(self,
203
- size: Optional[Size] = None,
204
- key: Optional[SeedOrKey] = None,
205
- dtype: DTypeLike = None):
206
- r = self.random(size=size, key=key, dtype=dtype)
207
- return r
248
+ r = self.random(size=size, key=key, dtype=dtype)
249
+ return r
208
250
 
209
- def ranf(self,
210
- size: Optional[Size] = None,
211
- key: Optional[SeedOrKey] = None,
212
- dtype: DTypeLike = None):
213
- r = self.random(size=size, key=key, dtype=dtype)
214
- return r
251
+ def sample(self,
252
+ size: Optional[Size] = None,
253
+ key: Optional[SeedOrKey] = None,
254
+ dtype: DTypeLike = None):
255
+ r = self.random(size=size, key=key, dtype=dtype)
256
+ return r
215
257
 
216
- def sample(self,
217
- size: Optional[Size] = None,
218
- key: Optional[SeedOrKey] = None,
219
- dtype: DTypeLike = None):
220
- r = self.random(size=size, key=key, dtype=dtype)
221
- return r
258
+ def choice(self,
259
+ a,
260
+ size: Optional[Size] = None,
261
+ replace=True,
262
+ p=None,
263
+ key: Optional[SeedOrKey] = None):
264
+ a = _check_py_seq(a)
265
+ p = _check_py_seq(p)
266
+ key = self.split_key() if key is None else _formalize_key(key)
267
+ r = jr.choice(key, a=a, shape=_size2shape(size), replace=replace, p=p)
268
+ return r
269
+
270
+ def permutation(self,
271
+ x,
272
+ axis: int = 0,
273
+ independent: bool = False,
274
+ key: Optional[SeedOrKey] = None):
275
+ x = _check_py_seq(x)
276
+ key = self.split_key() if key is None else _formalize_key(key)
277
+ r = permutation_for_unit(key, x, axis=axis, independent=independent)
278
+ return r
279
+
280
+ def shuffle(self,
281
+ x,
282
+ axis=0,
283
+ key: Optional[SeedOrKey] = None):
284
+ key = self.split_key() if key is None else _formalize_key(key)
285
+ x = permutation_for_unit(key, x, axis=axis)
286
+ return x
222
287
 
223
- def choice(self,
288
+ def beta(self,
224
289
  a,
225
- size: Optional[Size] = None,
226
- replace=True,
227
- p=None,
228
- key: Optional[SeedOrKey] = None):
229
- a = _check_py_seq(a)
230
- p = _check_py_seq(p)
231
- key = self.split_key() if key is None else _formalize_key(key)
232
- r = jr.choice(key, a=a, shape=_size2shape(size), replace=replace, p=p)
233
- return r
234
-
235
- def permutation(self,
236
- x,
237
- axis: int = 0,
238
- independent: bool = False,
239
- key: Optional[SeedOrKey] = None):
240
- x = _check_py_seq(x)
241
- key = self.split_key() if key is None else _formalize_key(key)
242
- r = permutation_for_unit(key, x, axis=axis, independent=independent)
243
- return r
244
-
245
- def shuffle(self,
246
- x,
247
- axis=0,
248
- key: Optional[SeedOrKey] = None):
249
- key = self.split_key() if key is None else _formalize_key(key)
250
- x = permutation_for_unit(key, x, axis=axis)
251
- return x
252
-
253
- def beta(self,
254
- a,
255
- b,
256
- size: Optional[Size] = None,
257
- key: Optional[SeedOrKey] = None,
258
- dtype: DTypeLike = None):
259
- a = _check_py_seq(a)
260
- b = _check_py_seq(b)
261
- if size is None:
262
- size = lax.broadcast_shapes(jnp.shape(a), jnp.shape(b))
263
- key = self.split_key() if key is None else _formalize_key(key)
264
- dtype = dtype or environ.dftype()
265
- r = jr.beta(key, a=a, b=b, shape=_size2shape(size), dtype=dtype)
266
- return r
267
-
268
- def exponential(self,
269
- scale=None,
270
- size: Optional[Size] = None,
271
- key: Optional[SeedOrKey] = None,
272
- dtype: DTypeLike = None):
273
- if size is None:
274
- size = jnp.shape(scale)
275
- key = self.split_key() if key is None else _formalize_key(key)
276
- dtype = dtype or environ.dftype()
277
- scale = jnp.asarray(scale, dtype=dtype)
278
- r = jr.exponential(key, shape=_size2shape(size), dtype=dtype)
279
- if scale is not None:
280
- r = r / scale
281
- return r
282
-
283
- def gamma(self,
284
- shape,
285
- scale=None,
286
- size: Optional[Size] = None,
287
- key: Optional[SeedOrKey] = None,
288
- dtype: DTypeLike = None):
289
- shape = _check_py_seq(shape)
290
- scale = _check_py_seq(scale)
291
- if size is None:
292
- size = lax.broadcast_shapes(jnp.shape(shape), jnp.shape(scale))
293
- key = self.split_key() if key is None else _formalize_key(key)
294
- dtype = dtype or environ.dftype()
295
- r = jr.gamma(key, a=shape, shape=_size2shape(size), dtype=dtype)
296
- if scale is not None:
297
- r = r * scale
298
- return r
299
-
300
- def gumbel(self,
301
- loc=None,
302
- scale=None,
290
+ b,
303
291
  size: Optional[Size] = None,
304
292
  key: Optional[SeedOrKey] = None,
305
293
  dtype: DTypeLike = None):
306
- loc = _check_py_seq(loc)
307
- scale = _check_py_seq(scale)
308
- if size is None:
309
- size = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
310
- key = self.split_key() if key is None else _formalize_key(key)
311
- dtype = dtype or environ.dftype()
312
- r = _loc_scale(loc, scale, jr.gumbel(key, shape=_size2shape(size), dtype=dtype))
313
- return r
314
-
315
- def laplace(self,
316
- loc=None,
294
+ a = _check_py_seq(a)
295
+ b = _check_py_seq(b)
296
+ if size is None:
297
+ size = lax.broadcast_shapes(jnp.shape(a), jnp.shape(b))
298
+ key = self.split_key() if key is None else _formalize_key(key)
299
+ dtype = dtype or environ.dftype()
300
+ r = jr.beta(key, a=a, b=b, shape=_size2shape(size), dtype=dtype)
301
+ return r
302
+
303
+ def exponential(self,
304
+ scale=None,
305
+ size: Optional[Size] = None,
306
+ key: Optional[SeedOrKey] = None,
307
+ dtype: DTypeLike = None):
308
+ if size is None:
309
+ size = jnp.shape(scale)
310
+ key = self.split_key() if key is None else _formalize_key(key)
311
+ dtype = dtype or environ.dftype()
312
+ scale = jnp.asarray(scale, dtype=dtype)
313
+ r = jr.exponential(key, shape=_size2shape(size), dtype=dtype)
314
+ if scale is not None:
315
+ r = r / scale
316
+ return r
317
+
318
+ def gamma(self,
319
+ shape,
317
320
  scale=None,
318
321
  size: Optional[Size] = None,
319
322
  key: Optional[SeedOrKey] = None,
320
323
  dtype: DTypeLike = None):
321
- loc = _check_py_seq(loc)
322
- scale = _check_py_seq(scale)
323
- if size is None:
324
- size = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
325
- key = self.split_key() if key is None else _formalize_key(key)
326
- dtype = dtype or environ.dftype()
327
- r = _loc_scale(loc, scale, jr.laplace(key, shape=_size2shape(size), dtype=dtype))
328
- return r
329
-
330
- def logistic(self,
324
+ shape = _check_py_seq(shape)
325
+ scale = _check_py_seq(scale)
326
+ if size is None:
327
+ size = lax.broadcast_shapes(jnp.shape(shape), jnp.shape(scale))
328
+ key = self.split_key() if key is None else _formalize_key(key)
329
+ dtype = dtype or environ.dftype()
330
+ r = jr.gamma(key, a=shape, shape=_size2shape(size), dtype=dtype)
331
+ if scale is not None:
332
+ r = r * scale
333
+ return r
334
+
335
+ def gumbel(self,
331
336
  loc=None,
332
337
  scale=None,
333
338
  size: Optional[Size] = None,
334
339
  key: Optional[SeedOrKey] = None,
335
340
  dtype: DTypeLike = None):
336
- loc = _check_py_seq(loc)
337
- scale = _check_py_seq(scale)
338
- if size is None:
339
- size = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
340
- key = self.split_key() if key is None else _formalize_key(key)
341
- dtype = dtype or environ.dftype()
342
- r = _loc_scale(loc, scale, jr.logistic(key, shape=_size2shape(size), dtype=dtype))
343
- return r
344
-
345
- def normal(self,
346
- loc=None,
347
- scale=None,
348
- size: Optional[Size] = None,
349
- key: Optional[SeedOrKey] = None,
350
- dtype: DTypeLike = None):
351
- loc = _check_py_seq(loc)
352
- scale = _check_py_seq(scale)
353
- if size is None:
354
- size = lax.broadcast_shapes(jnp.shape(scale), jnp.shape(loc))
355
- key = self.split_key() if key is None else _formalize_key(key)
356
- dtype = dtype or environ.dftype()
357
- r = _loc_scale(loc, scale, jr.normal(key, shape=_size2shape(size), dtype=dtype))
358
- return r
359
-
360
- def pareto(self,
361
- a,
362
- size: Optional[Size] = None,
363
- key: Optional[SeedOrKey] = None,
364
- dtype: DTypeLike = None):
365
- if size is None:
366
- size = jnp.shape(a)
367
- key = self.split_key() if key is None else _formalize_key(key)
368
- dtype = dtype or environ.dftype()
369
- a = jnp.asarray(a, dtype=dtype)
370
- r = jr.pareto(key, b=a, shape=_size2shape(size), dtype=dtype)
371
- return r
372
-
373
- def poisson(self,
374
- lam=1.0,
375
- size: Optional[Size] = None,
376
- key: Optional[SeedOrKey] = None,
377
- dtype: DTypeLike = None):
378
- lam = _check_py_seq(lam)
379
- if size is None:
380
- size = jnp.shape(lam)
381
- key = self.split_key() if key is None else _formalize_key(key)
382
- dtype = dtype or environ.ditype()
383
- r = jr.poisson(key, lam=lam, shape=_size2shape(size), dtype=dtype)
384
- return r
385
-
386
- def standard_cauchy(self,
387
- size: Optional[Size] = None,
388
- key: Optional[SeedOrKey] = None,
389
- dtype: DTypeLike = None):
390
- key = self.split_key() if key is None else _formalize_key(key)
391
- dtype = dtype or environ.dftype()
392
- r = jr.cauchy(key, shape=_size2shape(size), dtype=dtype)
393
- return r
394
-
395
- def standard_exponential(self,
396
- size: Optional[Size] = None,
397
- key: Optional[SeedOrKey] = None,
398
- dtype: DTypeLike = None):
399
- key = self.split_key() if key is None else _formalize_key(key)
400
- dtype = dtype or environ.dftype()
401
- r = jr.exponential(key, shape=_size2shape(size), dtype=dtype)
402
- return r
403
-
404
- def standard_gamma(self,
405
- shape,
406
- size: Optional[Size] = None,
407
- key: Optional[SeedOrKey] = None,
408
- dtype: DTypeLike = None):
409
- shape = _check_py_seq(shape)
410
- if size is None:
411
- size = jnp.shape(shape)
412
- key = self.split_key() if key is None else _formalize_key(key)
413
- dtype = dtype or environ.dftype()
414
- r = jr.gamma(key, a=shape, shape=_size2shape(size), dtype=dtype)
415
- return r
416
-
417
- def standard_normal(self,
418
- size: Optional[Size] = None,
419
- key: Optional[SeedOrKey] = None,
420
- dtype: DTypeLike = None):
421
- key = self.split_key() if key is None else _formalize_key(key)
422
- dtype = dtype or environ.dftype()
423
- r = jr.normal(key, shape=_size2shape(size), dtype=dtype)
424
- return r
425
-
426
- def standard_t(self, df,
341
+ loc = _check_py_seq(loc)
342
+ scale = _check_py_seq(scale)
343
+ if size is None:
344
+ size = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
345
+ key = self.split_key() if key is None else _formalize_key(key)
346
+ dtype = dtype or environ.dftype()
347
+ r = _loc_scale(loc, scale, jr.gumbel(key, shape=_size2shape(size), dtype=dtype))
348
+ return r
349
+
350
+ def laplace(self,
351
+ loc=None,
352
+ scale=None,
353
+ size: Optional[Size] = None,
354
+ key: Optional[SeedOrKey] = None,
355
+ dtype: DTypeLike = None):
356
+ loc = _check_py_seq(loc)
357
+ scale = _check_py_seq(scale)
358
+ if size is None:
359
+ size = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
360
+ key = self.split_key() if key is None else _formalize_key(key)
361
+ dtype = dtype or environ.dftype()
362
+ r = _loc_scale(loc, scale, jr.laplace(key, shape=_size2shape(size), dtype=dtype))
363
+ return r
364
+
365
+ def logistic(self,
366
+ loc=None,
367
+ scale=None,
427
368
  size: Optional[Size] = None,
428
369
  key: Optional[SeedOrKey] = None,
429
370
  dtype: DTypeLike = None):
430
- df = _check_py_seq(df)
431
- if size is None:
432
- size = jnp.shape(size)
433
- key = self.split_key() if key is None else _formalize_key(key)
434
- dtype = dtype or environ.dftype()
435
- r = jr.t(key, df=df, shape=_size2shape(size), dtype=dtype)
436
- return r
437
-
438
- def uniform(self,
439
- low=0.0,
440
- high=1.0,
441
- size: Optional[Size] = None,
442
- key: Optional[SeedOrKey] = None,
443
- dtype: DTypeLike = None):
444
- low = _check_py_seq(low)
445
- high = _check_py_seq(high)
446
- if size is None:
447
- size = lax.broadcast_shapes(jnp.shape(low), jnp.shape(high))
448
- key = self.split_key() if key is None else _formalize_key(key)
449
- dtype = dtype or environ.dftype()
450
- r = uniform_for_unit(key, shape=_size2shape(size), minval=low, maxval=high, dtype=dtype)
451
- return r
452
-
453
- def __norm_cdf(self, x, sqrt2, dtype):
454
- # Computes standard normal cumulative distribution function
455
- return (np.asarray(1., dtype) + lax.erf(x / sqrt2)) / np.asarray(2., dtype)
456
-
457
- def truncated_normal(
458
- self,
459
- lower,
460
- upper,
461
- size: Optional[Size] = None,
462
- loc=0.,
463
- scale=1.,
464
- key: Optional[SeedOrKey] = None,
465
- dtype: DTypeLike = None
466
- ):
467
- lower = _check_py_seq(lower)
468
- upper = _check_py_seq(upper)
469
- loc = _check_py_seq(loc)
470
- scale = _check_py_seq(scale)
471
- dtype = dtype or environ.dftype()
472
-
473
- lower = u.math.asarray(lower, dtype=dtype)
474
- upper = u.math.asarray(upper, dtype=dtype)
475
- loc = u.math.asarray(loc, dtype=dtype)
476
- scale = u.math.asarray(scale, dtype=dtype)
477
- unit = u.get_unit(lower)
478
- lower, upper, loc, scale = (
479
- lower.mantissa if isinstance(lower, u.Quantity) else lower,
480
- u.Quantity(upper).in_unit(unit).mantissa,
481
- u.Quantity(loc).in_unit(unit).mantissa,
482
- u.Quantity(scale).in_unit(unit).mantissa
483
- )
484
-
485
- jit_error_if(
486
- u.math.any(u.math.logical_or(loc < lower - 2 * scale, loc > upper + 2 * scale)),
487
- "mean is more than 2 std from [lower, upper] in truncated_normal. "
488
- "The distribution of values may be incorrect."
489
- )
490
-
491
- if size is None:
492
- size = u.math.broadcast_shapes(jnp.shape(lower),
493
- jnp.shape(upper),
494
- jnp.shape(loc),
495
- jnp.shape(scale))
496
-
497
- # Values are generated by using a truncated uniform distribution and
498
- # then using the inverse CDF for the normal distribution.
499
- # Get upper and lower cdf values
500
- sqrt2 = np.array(np.sqrt(2), dtype=dtype)
501
- l = self.__norm_cdf((lower - loc) / scale, sqrt2, dtype)
502
- u_ = self.__norm_cdf((upper - loc) / scale, sqrt2, dtype)
503
-
504
- # Uniformly fill tensor with values from [l, u], then translate to
505
- # [2l-1, 2u-1].
506
- key = self.split_key() if key is None else _formalize_key(key)
507
- out = uniform_for_unit(
508
- key, size, dtype,
509
- minval=lax.nextafter(2 * l - 1, np.array(np.inf, dtype=dtype)),
510
- maxval=lax.nextafter(2 * u_- 1, np.array(-np.inf, dtype=dtype))
511
- )
512
-
513
- # Use inverse cdf transform for normal distribution to get truncated
514
- # standard normal
515
- out = lax.erf_inv(out)
516
-
517
- # Transform to proper mean, std
518
- out = out * scale * sqrt2 + loc
519
-
520
- # Clamp to ensure it's in the proper range
521
- out = jnp.clip(
522
- out,
523
- lax.nextafter(lax.stop_gradient(lower), np.array(np.inf, dtype=dtype)),
524
- lax.nextafter(lax.stop_gradient(upper), np.array(-np.inf, dtype=dtype))
525
- )
526
- return out if unit.is_unitless else u.Quantity(out, unit=unit)
527
-
528
- def _check_p(self, p):
529
- raise ValueError(f'Parameter p should be within [0, 1], but we got {p}')
530
-
531
- def bernoulli(self,
532
- p,
533
- size: Optional[Size] = None,
534
- key: Optional[SeedOrKey] = None):
535
- p = _check_py_seq(p)
536
- jit_error_if(jnp.any(jnp.logical_and(p < 0, p > 1)), self._check_p, p)
537
- if size is None:
538
- size = jnp.shape(p)
539
- key = self.split_key() if key is None else _formalize_key(key)
540
- r = jr.bernoulli(key, p=p, shape=_size2shape(size))
541
- return r
542
-
543
- def lognormal(
544
- self,
545
- mean=None,
546
- sigma=None,
547
- size: Optional[Size] = None,
548
- key: Optional[SeedOrKey] = None,
549
- dtype: DTypeLike = None
550
- ):
551
- mean = _check_py_seq(mean)
552
- sigma = _check_py_seq(sigma)
553
- mean = u.math.asarray(mean, dtype=dtype)
554
- sigma = u.math.asarray(sigma, dtype=dtype)
555
- unit = mean.unit if isinstance(mean, u.Quantity) else u.Unit()
556
- mean = mean.mantissa if isinstance(mean, u.Quantity) else mean
557
- sigma = sigma.in_unit(unit).mantissa if isinstance(sigma, u.Quantity) else sigma
558
-
559
- if size is None:
560
- size = jnp.broadcast_shapes(
561
- jnp.shape(mean),
562
- jnp.shape(sigma)
563
- )
564
- key = self.split_key() if key is None else _formalize_key(key)
565
- dtype = dtype or environ.dftype()
566
- samples = jr.normal(key, shape=_size2shape(size), dtype=dtype)
567
- samples = _loc_scale(mean, sigma, samples)
568
- samples = jnp.exp(samples)
569
- return samples if unit.is_unitless else u.Quantity(samples, unit=unit)
570
-
571
- def binomial(self,
572
- n,
573
- p,
371
+ loc = _check_py_seq(loc)
372
+ scale = _check_py_seq(scale)
373
+ if size is None:
374
+ size = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
375
+ key = self.split_key() if key is None else _formalize_key(key)
376
+ dtype = dtype or environ.dftype()
377
+ r = _loc_scale(loc, scale, jr.logistic(key, shape=_size2shape(size), dtype=dtype))
378
+ return r
379
+
380
+ def normal(self,
381
+ loc=None,
382
+ scale=None,
574
383
  size: Optional[Size] = None,
575
384
  key: Optional[SeedOrKey] = None,
576
385
  dtype: DTypeLike = None):
577
- n = _check_py_seq(n)
578
- p = _check_py_seq(p)
579
- jit_error_if(jnp.any(jnp.logical_and(p < 0, p > 1)), self._check_p, p)
580
- if size is None:
581
- size = jnp.broadcast_shapes(jnp.shape(n), jnp.shape(p))
582
- key = self.split_key() if key is None else _formalize_key(key)
583
- r = _binomial(key, p, n, shape=_size2shape(size))
584
- dtype = dtype or environ.ditype()
585
- return jnp.asarray(r, dtype=dtype)
586
-
587
- def chisquare(self,
588
- df,
589
- size: Optional[Size] = None,
590
- key: Optional[SeedOrKey] = None,
591
- dtype: DTypeLike = None):
592
- df = _check_py_seq(df)
593
- key = self.split_key() if key is None else _formalize_key(key)
594
- dtype = dtype or environ.dftype()
595
- if size is None:
596
- if jnp.ndim(df) == 0:
597
- dist = jr.normal(key, (df,), dtype=dtype) ** 2
598
- dist = dist.sum()
599
- else:
600
- raise NotImplementedError('Do not support non-scale "df" when "size" is None')
601
- else:
602
- dist = jr.normal(key, (df,) + _size2shape(size), dtype=dtype) ** 2
603
- dist = dist.sum(axis=0)
604
- return dist
605
-
606
- def dirichlet(self,
607
- alpha,
386
+ loc = _check_py_seq(loc)
387
+ scale = _check_py_seq(scale)
388
+ if size is None:
389
+ size = lax.broadcast_shapes(jnp.shape(scale), jnp.shape(loc))
390
+ key = self.split_key() if key is None else _formalize_key(key)
391
+ dtype = dtype or environ.dftype()
392
+ r = _loc_scale(loc, scale, jr.normal(key, shape=_size2shape(size), dtype=dtype))
393
+ return r
394
+
395
+ def pareto(self,
396
+ a,
397
+ size: Optional[Size] = None,
398
+ key: Optional[SeedOrKey] = None,
399
+ dtype: DTypeLike = None):
400
+ if size is None:
401
+ size = jnp.shape(a)
402
+ key = self.split_key() if key is None else _formalize_key(key)
403
+ dtype = dtype or environ.dftype()
404
+ a = jnp.asarray(a, dtype=dtype)
405
+ r = jr.pareto(key, b=a, shape=_size2shape(size), dtype=dtype)
406
+ return r
407
+
408
+ def poisson(self,
409
+ lam=1.0,
608
410
  size: Optional[Size] = None,
609
411
  key: Optional[SeedOrKey] = None,
610
412
  dtype: DTypeLike = None):
611
- key = self.split_key() if key is None else _formalize_key(key)
612
- alpha = _check_py_seq(alpha)
613
- dtype = dtype or environ.dftype()
614
- r = jr.dirichlet(key, alpha=alpha, shape=_size2shape(size), dtype=dtype)
615
- return r
413
+ lam = _check_py_seq(lam)
414
+ if size is None:
415
+ size = jnp.shape(lam)
416
+ key = self.split_key() if key is None else _formalize_key(key)
417
+ dtype = dtype or environ.ditype()
418
+ r = jr.poisson(key, lam=lam, shape=_size2shape(size), dtype=dtype)
419
+ return r
420
+
421
+ def standard_cauchy(self,
422
+ size: Optional[Size] = None,
423
+ key: Optional[SeedOrKey] = None,
424
+ dtype: DTypeLike = None):
425
+ key = self.split_key() if key is None else _formalize_key(key)
426
+ dtype = dtype or environ.dftype()
427
+ r = jr.cauchy(key, shape=_size2shape(size), dtype=dtype)
428
+ return r
429
+
430
+ def standard_exponential(self,
431
+ size: Optional[Size] = None,
432
+ key: Optional[SeedOrKey] = None,
433
+ dtype: DTypeLike = None):
434
+ key = self.split_key() if key is None else _formalize_key(key)
435
+ dtype = dtype or environ.dftype()
436
+ r = jr.exponential(key, shape=_size2shape(size), dtype=dtype)
437
+ return r
438
+
439
+ def standard_gamma(self,
440
+ shape,
441
+ size: Optional[Size] = None,
442
+ key: Optional[SeedOrKey] = None,
443
+ dtype: DTypeLike = None):
444
+ shape = _check_py_seq(shape)
445
+ if size is None:
446
+ size = jnp.shape(shape)
447
+ key = self.split_key() if key is None else _formalize_key(key)
448
+ dtype = dtype or environ.dftype()
449
+ r = jr.gamma(key, a=shape, shape=_size2shape(size), dtype=dtype)
450
+ return r
451
+
452
+ def standard_normal(self,
453
+ size: Optional[Size] = None,
454
+ key: Optional[SeedOrKey] = None,
455
+ dtype: DTypeLike = None):
456
+ key = self.split_key() if key is None else _formalize_key(key)
457
+ dtype = dtype or environ.dftype()
458
+ r = jr.normal(key, shape=_size2shape(size), dtype=dtype)
459
+ return r
616
460
 
617
- def geometric(self,
618
- p,
461
+ def standard_t(self, df,
462
+ size: Optional[Size] = None,
463
+ key: Optional[SeedOrKey] = None,
464
+ dtype: DTypeLike = None):
465
+ df = _check_py_seq(df)
466
+ if size is None:
467
+ size = jnp.shape(size)
468
+ key = self.split_key() if key is None else _formalize_key(key)
469
+ dtype = dtype or environ.dftype()
470
+ r = jr.t(key, df=df, shape=_size2shape(size), dtype=dtype)
471
+ return r
472
+
473
+ def uniform(self,
474
+ low=0.0,
475
+ high=1.0,
619
476
  size: Optional[Size] = None,
620
477
  key: Optional[SeedOrKey] = None,
621
478
  dtype: DTypeLike = None):
622
- p = _check_py_seq(p)
623
- if size is None:
624
- size = jnp.shape(p)
625
- key = self.split_key() if key is None else _formalize_key(key)
626
- dtype = dtype or environ.dftype()
627
- u_ = uniform_for_unit(key, size, dtype=dtype)
628
- r = jnp.floor(jnp.log1p(-u_) / jnp.log1p(-p))
629
- return r
630
-
631
- def _check_p2(self, p):
632
- raise ValueError(f'We require `sum(pvals[:-1]) <= 1`. But we got {p}')
633
-
634
- def multinomial(self,
635
- n,
636
- pvals,
479
+ low = _check_py_seq(low)
480
+ high = _check_py_seq(high)
481
+ if size is None:
482
+ size = lax.broadcast_shapes(jnp.shape(low), jnp.shape(high))
483
+ key = self.split_key() if key is None else _formalize_key(key)
484
+ dtype = dtype or environ.dftype()
485
+ r = uniform_for_unit(key, shape=_size2shape(size), minval=low, maxval=high, dtype=dtype)
486
+ return r
487
+
488
+ def __norm_cdf(self, x, sqrt2, dtype):
489
+ # Computes standard normal cumulative distribution function
490
+ return (np.asarray(1., dtype) + lax.erf(x / sqrt2)) / np.asarray(2., dtype)
491
+
492
+ def truncated_normal(
493
+ self,
494
+ lower,
495
+ upper,
496
+ size: Optional[Size] = None,
497
+ loc=0.,
498
+ scale=1.,
499
+ key: Optional[SeedOrKey] = None,
500
+ dtype: DTypeLike = None
501
+ ):
502
+ lower = _check_py_seq(lower)
503
+ upper = _check_py_seq(upper)
504
+ loc = _check_py_seq(loc)
505
+ scale = _check_py_seq(scale)
506
+ dtype = dtype or environ.dftype()
507
+
508
+ lower = u.math.asarray(lower, dtype=dtype)
509
+ upper = u.math.asarray(upper, dtype=dtype)
510
+ loc = u.math.asarray(loc, dtype=dtype)
511
+ scale = u.math.asarray(scale, dtype=dtype)
512
+ unit = u.get_unit(lower)
513
+ lower, upper, loc, scale = (
514
+ lower.mantissa if isinstance(lower, u.Quantity) else lower,
515
+ u.Quantity(upper).in_unit(unit).mantissa,
516
+ u.Quantity(loc).in_unit(unit).mantissa,
517
+ u.Quantity(scale).in_unit(unit).mantissa
518
+ )
519
+
520
+ jit_error_if(
521
+ u.math.any(u.math.logical_or(loc < lower - 2 * scale, loc > upper + 2 * scale)),
522
+ "mean is more than 2 std from [lower, upper] in truncated_normal. "
523
+ "The distribution of values may be incorrect."
524
+ )
525
+
526
+ if size is None:
527
+ size = u.math.broadcast_shapes(jnp.shape(lower),
528
+ jnp.shape(upper),
529
+ jnp.shape(loc),
530
+ jnp.shape(scale))
531
+
532
+ # Values are generated by using a truncated uniform distribution and
533
+ # then using the inverse CDF for the normal distribution.
534
+ # Get upper and lower cdf values
535
+ sqrt2 = np.array(np.sqrt(2), dtype=dtype)
536
+ l = self.__norm_cdf((lower - loc) / scale, sqrt2, dtype)
537
+ u_ = self.__norm_cdf((upper - loc) / scale, sqrt2, dtype)
538
+
539
+ # Uniformly fill tensor with values from [l, u], then translate to
540
+ # [2l-1, 2u-1].
541
+ key = self.split_key() if key is None else _formalize_key(key)
542
+ out = uniform_for_unit(
543
+ key, size, dtype,
544
+ minval=lax.nextafter(2 * l - 1, np.array(np.inf, dtype=dtype)),
545
+ maxval=lax.nextafter(2 * u_ - 1, np.array(-np.inf, dtype=dtype))
546
+ )
547
+
548
+ # Use inverse cdf transform for normal distribution to get truncated
549
+ # standard normal
550
+ out = lax.erf_inv(out)
551
+
552
+ # Transform to proper mean, std
553
+ out = out * scale * sqrt2 + loc
554
+
555
+ # Clamp to ensure it's in the proper range
556
+ out = jnp.clip(
557
+ out,
558
+ lax.nextafter(lax.stop_gradient(lower), np.array(np.inf, dtype=dtype)),
559
+ lax.nextafter(lax.stop_gradient(upper), np.array(-np.inf, dtype=dtype))
560
+ )
561
+ return out if unit.is_unitless else u.Quantity(out, unit=unit)
562
+
563
+ def _check_p(self, p):
564
+ raise ValueError(f'Parameter p should be within [0, 1], but we got {p}')
565
+
566
+ def bernoulli(self,
567
+ p,
568
+ size: Optional[Size] = None,
569
+ key: Optional[SeedOrKey] = None):
570
+ p = _check_py_seq(p)
571
+ jit_error_if(jnp.any(jnp.logical_and(p < 0, p > 1)), self._check_p, p)
572
+ if size is None:
573
+ size = jnp.shape(p)
574
+ key = self.split_key() if key is None else _formalize_key(key)
575
+ r = jr.bernoulli(key, p=p, shape=_size2shape(size))
576
+ return r
577
+
578
+ def lognormal(
579
+ self,
580
+ mean=None,
581
+ sigma=None,
582
+ size: Optional[Size] = None,
583
+ key: Optional[SeedOrKey] = None,
584
+ dtype: DTypeLike = None
585
+ ):
586
+ mean = _check_py_seq(mean)
587
+ sigma = _check_py_seq(sigma)
588
+ mean = u.math.asarray(mean, dtype=dtype)
589
+ sigma = u.math.asarray(sigma, dtype=dtype)
590
+ unit = mean.unit if isinstance(mean, u.Quantity) else u.Unit()
591
+ mean = mean.mantissa if isinstance(mean, u.Quantity) else mean
592
+ sigma = sigma.in_unit(unit).mantissa if isinstance(sigma, u.Quantity) else sigma
593
+
594
+ if size is None:
595
+ size = jnp.broadcast_shapes(
596
+ jnp.shape(mean),
597
+ jnp.shape(sigma)
598
+ )
599
+ key = self.split_key() if key is None else _formalize_key(key)
600
+ dtype = dtype or environ.dftype()
601
+ samples = jr.normal(key, shape=_size2shape(size), dtype=dtype)
602
+ samples = _loc_scale(mean, sigma, samples)
603
+ samples = jnp.exp(samples)
604
+ return samples if unit.is_unitless else u.Quantity(samples, unit=unit)
605
+
606
+ def binomial(self,
607
+ n,
608
+ p,
609
+ size: Optional[Size] = None,
610
+ key: Optional[SeedOrKey] = None,
611
+ dtype: DTypeLike = None):
612
+ n = _check_py_seq(n)
613
+ p = _check_py_seq(p)
614
+ jit_error_if(jnp.any(jnp.logical_and(p < 0, p > 1)), self._check_p, p)
615
+ if size is None:
616
+ size = jnp.broadcast_shapes(jnp.shape(n), jnp.shape(p))
617
+ key = self.split_key() if key is None else _formalize_key(key)
618
+ r = _binomial(key, p, n, shape=_size2shape(size))
619
+ dtype = dtype or environ.ditype()
620
+ return jnp.asarray(r, dtype=dtype)
621
+
622
+ def chisquare(self,
623
+ df,
637
624
  size: Optional[Size] = None,
638
625
  key: Optional[SeedOrKey] = None,
639
626
  dtype: DTypeLike = None):
640
- key = self.split_key() if key is None else _formalize_key(key)
641
- n = _check_py_seq(n)
642
- pvals = _check_py_seq(pvals)
643
- jit_error_if(jnp.sum(pvals[:-1]) > 1., self._check_p2, pvals)
644
- if isinstance(n, jax.core.Tracer):
645
- raise ValueError("The total count parameter `n` should not be a jax abstract array.")
646
- size = _size2shape(size)
647
- n_max = int(np.max(jax.device_get(n)))
648
- batch_shape = lax.broadcast_shapes(jnp.shape(pvals)[:-1], jnp.shape(n))
649
- r = _multinomial(key, pvals, n, n_max, batch_shape + size)
650
- dtype = dtype or environ.ditype()
651
- return jnp.asarray(r, dtype=dtype)
652
-
653
- def multivariate_normal(
654
- self,
655
- mean,
656
- cov,
657
- size: Optional[Size] = None,
658
- method: str = 'cholesky',
659
- key: Optional[SeedOrKey] = None,
660
- dtype: DTypeLike = None
661
- ):
662
- if method not in {'svd', 'eigh', 'cholesky'}:
663
- raise ValueError("method must be one of {'svd', 'eigh', 'cholesky'}")
664
- dtype = dtype or environ.dftype()
665
- mean = u.math.asarray(_check_py_seq(mean), dtype=dtype)
666
- cov = u.math.asarray(_check_py_seq(cov), dtype=dtype)
667
- if isinstance(mean, u.Quantity):
668
- assert isinstance(cov, u.Quantity)
669
- assert mean.unit ** 2 == cov.unit
670
- mean = mean.mantissa if isinstance(mean, u.Quantity) else mean
671
- cov = cov.mantissa if isinstance(cov, u.Quantity) else cov
672
- unit = mean.unit if isinstance(mean, u.Quantity) else u.Unit()
673
-
674
- key = self.split_key() if key is None else _formalize_key(key)
675
- if not jnp.ndim(mean) >= 1:
676
- raise ValueError(f"multivariate_normal requires mean.ndim >= 1, got mean.ndim == {jnp.ndim(mean)}")
677
- if not jnp.ndim(cov) >= 2:
678
- raise ValueError(f"multivariate_normal requires cov.ndim >= 2, got cov.ndim == {jnp.ndim(cov)}")
679
- n = mean.shape[-1]
680
- if jnp.shape(cov)[-2:] != (n, n):
681
- raise ValueError(f"multivariate_normal requires cov.shape == (..., n, n) for n={n}, "
682
- f"but got cov.shape == {jnp.shape(cov)}.")
683
- if size is None:
684
- size = lax.broadcast_shapes(mean.shape[:-1], cov.shape[:-2])
685
- else:
686
- size = _size2shape(size)
687
- _check_shape("normal", size, mean.shape[:-1], cov.shape[:-2])
688
-
689
- if method == 'svd':
690
- (u_, s, _) = jnp.linalg.svd(cov)
691
- factor = u_ * jnp.sqrt(s[..., None, :])
692
- elif method == 'eigh':
693
- (w, v) = jnp.linalg.eigh(cov)
694
- factor = v * jnp.sqrt(w[..., None, :])
695
- else: # 'cholesky'
696
- factor = jnp.linalg.cholesky(cov)
697
- normal_samples = jr.normal(key, size + mean.shape[-1:], dtype=dtype)
698
- r = mean + jnp.einsum('...ij,...j->...i', factor, normal_samples)
699
- return r if unit.is_unitless else u.Quantity(r, unit=unit)
700
-
701
- def rayleigh(self,
702
- scale=1.0,
703
- size: Optional[Size] = None,
704
- key: Optional[SeedOrKey] = None,
705
- dtype: DTypeLike = None):
706
- scale = _check_py_seq(scale)
707
- if size is None:
708
- size = jnp.shape(scale)
709
- key = self.split_key() if key is None else _formalize_key(key)
710
- dtype = dtype or environ.dftype()
711
- x = jnp.sqrt(-2. * jnp.log(uniform_for_unit(key, shape=_size2shape(size), minval=0, maxval=1, dtype=dtype)))
712
- r = x * scale
713
- return r
714
-
715
- def triangular(self,
716
- size: Optional[Size] = None,
717
- key: Optional[SeedOrKey] = None):
718
- key = self.split_key() if key is None else _formalize_key(key)
719
- bernoulli_samples = jr.bernoulli(key, p=0.5, shape=_size2shape(size))
720
- r = 2 * bernoulli_samples - 1
721
- return r
722
-
723
- def vonmises(self,
724
- mu,
725
- kappa,
726
- size: Optional[Size] = None,
727
- key: Optional[SeedOrKey] = None,
728
- dtype: DTypeLike = None):
729
- key = self.split_key() if key is None else _formalize_key(key)
730
- dtype = dtype or environ.dftype()
731
- mu = jnp.asarray(_check_py_seq(mu), dtype=dtype)
732
- kappa = jnp.asarray(_check_py_seq(kappa), dtype=dtype)
733
- if size is None:
734
- size = lax.broadcast_shapes(jnp.shape(mu), jnp.shape(kappa))
735
- size = _size2shape(size)
736
- samples = _von_mises_centered(key, kappa, size, dtype=dtype)
737
- samples = samples + mu
738
- samples = (samples + jnp.pi) % (2.0 * jnp.pi) - jnp.pi
739
- return samples
740
-
741
- def weibull(self,
742
- a,
743
- size: Optional[Size] = None,
744
- key: Optional[SeedOrKey] = None,
745
- dtype: DTypeLike = None):
746
- key = self.split_key() if key is None else _formalize_key(key)
747
- a = _check_py_seq(a)
748
- if size is None:
749
- size = jnp.shape(a)
750
- else:
751
- if jnp.size(a) > 1:
752
- raise ValueError(f'"a" should be a scalar when "size" is provided. But we got {a}')
753
- size = _size2shape(size)
754
- dtype = dtype or environ.dftype()
755
- random_uniform = uniform_for_unit(key=key, shape=size, minval=0, maxval=1, dtype=dtype)
756
- r = jnp.power(-jnp.log1p(-random_uniform), 1.0 / a)
757
- return r
758
-
759
- def weibull_min(self,
760
- a,
761
- scale=None,
627
+ df = _check_py_seq(df)
628
+ key = self.split_key() if key is None else _formalize_key(key)
629
+ dtype = dtype or environ.dftype()
630
+ if size is None:
631
+ if jnp.ndim(df) == 0:
632
+ dist = jr.normal(key, (df,), dtype=dtype) ** 2
633
+ dist = dist.sum()
634
+ else:
635
+ raise NotImplementedError('Do not support non-scale "df" when "size" is None')
636
+ else:
637
+ dist = jr.normal(key, (df,) + _size2shape(size), dtype=dtype) ** 2
638
+ dist = dist.sum(axis=0)
639
+ return dist
640
+
641
+ def dirichlet(self,
642
+ alpha,
762
643
  size: Optional[Size] = None,
763
644
  key: Optional[SeedOrKey] = None,
764
645
  dtype: DTypeLike = None):
765
- key = self.split_key() if key is None else _formalize_key(key)
766
- a = _check_py_seq(a)
767
- scale = _check_py_seq(scale)
768
- if size is None:
769
- size = jnp.broadcast_shapes(jnp.shape(a), jnp.shape(scale))
770
- else:
771
- if jnp.size(a) > 1:
772
- raise ValueError(f'"a" should be a scalar when "size" is provided. But we got {a}')
773
- size = _size2shape(size)
774
- dtype = dtype or environ.dftype()
775
- random_uniform = uniform_for_unit(key=key, shape=size, minval=0, maxval=1, dtype=dtype)
776
- r = jnp.power(-jnp.log1p(-random_uniform), 1.0 / a)
777
- if scale is not None:
778
- r /= scale
779
- return r
780
-
781
- def maxwell(self,
782
- size: Optional[Size] = None,
783
- key: Optional[SeedOrKey] = None,
784
- dtype: DTypeLike = None):
785
- key = self.split_key() if key is None else _formalize_key(key)
786
- shape = _size2shape(size) + (3,)
787
- dtype = dtype or environ.dftype()
788
- norm_rvs = jr.normal(key=key, shape=shape, dtype=dtype)
789
- r = jnp.linalg.norm(norm_rvs, axis=-1)
790
- return r
791
-
792
- def negative_binomial(self,
793
- n,
794
- p,
795
- size: Optional[Size] = None,
796
- key: Optional[SeedOrKey] = None,
797
- dtype: DTypeLike = None):
798
- n = _check_py_seq(n)
799
- p = _check_py_seq(p)
800
- if size is None:
801
- size = lax.broadcast_shapes(jnp.shape(n), jnp.shape(p))
802
- size = _size2shape(size)
803
- logits = jnp.log(p) - jnp.log1p(-p)
804
- if key is None:
805
- keys = self.split_keys(2)
806
- else:
807
- keys = jr.split(_formalize_key(key), 2)
808
- rate = self.gamma(shape=n, scale=jnp.exp(-logits), size=size, key=keys[0], dtype=environ.dftype())
809
- r = self.poisson(lam=rate, key=keys[1], dtype=dtype or environ.ditype())
810
- return r
811
-
812
- def wald(self,
813
- mean,
814
- scale,
815
- size: Optional[Size] = None,
816
- key: Optional[SeedOrKey] = None,
817
- dtype: DTypeLike = None):
818
- dtype = dtype or environ.dftype()
819
- key = self.split_key() if key is None else _formalize_key(key)
820
- mean = jnp.asarray(_check_py_seq(mean), dtype=dtype)
821
- scale = jnp.asarray(_check_py_seq(scale), dtype=dtype)
822
- if size is None:
823
- size = lax.broadcast_shapes(jnp.shape(mean), jnp.shape(scale))
824
- size = _size2shape(size)
825
- sampled_chi2 = jnp.square(self.randn(*size))
826
- sampled_uniform = self.uniform(size=size, key=key, dtype=dtype)
827
- # Wikipedia defines an intermediate x with the formula
828
- # x = loc + loc ** 2 * y / (2 * conc) - loc / (2 * conc) * sqrt(4 * loc * conc * y + loc ** 2 * y ** 2)
829
- # where y ~ N(0, 1)**2 (sampled_chi2 above) and conc is the concentration.
830
- # Let us write
831
- # w = loc * y / (2 * conc)
832
- # Then we can extract the common factor in the last two terms to obtain
833
- # x = loc + loc * w * (1 - sqrt(2 / w + 1))
834
- # Now we see that the Wikipedia formula suffers from catastrphic
835
- # cancellation for large w (e.g., if conc << loc).
836
- #
837
- # Fortunately, we can fix this by multiplying both sides
838
- # by 1 + sqrt(2 / w + 1). We get
839
- # x * (1 + sqrt(2 / w + 1)) =
840
- # = loc * (1 + sqrt(2 / w + 1)) + loc * w * (1 - (2 / w + 1))
841
- # = loc * (sqrt(2 / w + 1) - 1)
842
- # The term sqrt(2 / w + 1) + 1 no longer presents numerical
843
- # difficulties for large w, and sqrt(2 / w + 1) - 1 is just
844
- # sqrt1pm1(2 / w), which we know how to compute accurately.
845
- # This just leaves the matter of small w, where 2 / w may
846
- # overflow. In the limit a w -> 0, x -> loc, so we just mask
847
- # that case.
848
- sqrt1pm1_arg = 4 * scale / (mean * sampled_chi2) # 2 / w above
849
- safe_sqrt1pm1_arg = jnp.where(sqrt1pm1_arg < np.inf, sqrt1pm1_arg, 1.0)
850
- denominator = 1.0 + jnp.sqrt(safe_sqrt1pm1_arg + 1.0)
851
- ratio = jnp.expm1(0.5 * jnp.log1p(safe_sqrt1pm1_arg)) / denominator
852
- sampled = mean * jnp.where(sqrt1pm1_arg < np.inf, ratio, 1.0) # x above
853
- res = jnp.where(sampled_uniform <= mean / (mean + sampled),
854
- sampled,
855
- jnp.square(mean) / sampled)
856
- return res
857
-
858
- def t(self,
859
- df,
646
+ key = self.split_key() if key is None else _formalize_key(key)
647
+ alpha = _check_py_seq(alpha)
648
+ dtype = dtype or environ.dftype()
649
+ r = jr.dirichlet(key, alpha=alpha, shape=_size2shape(size), dtype=dtype)
650
+ return r
651
+
652
+ def geometric(self,
653
+ p,
654
+ size: Optional[Size] = None,
655
+ key: Optional[SeedOrKey] = None,
656
+ dtype: DTypeLike = None):
657
+ p = _check_py_seq(p)
658
+ if size is None:
659
+ size = jnp.shape(p)
660
+ key = self.split_key() if key is None else _formalize_key(key)
661
+ dtype = dtype or environ.dftype()
662
+ u_ = uniform_for_unit(key, size, dtype=dtype)
663
+ r = jnp.floor(jnp.log1p(-u_) / jnp.log1p(-p))
664
+ return r
665
+
666
+ def _check_p2(self, p):
667
+ raise ValueError(f'We require `sum(pvals[:-1]) <= 1`. But we got {p}')
668
+
669
+ def multinomial(self,
670
+ n,
671
+ pvals,
672
+ size: Optional[Size] = None,
673
+ key: Optional[SeedOrKey] = None,
674
+ dtype: DTypeLike = None):
675
+ key = self.split_key() if key is None else _formalize_key(key)
676
+ n = _check_py_seq(n)
677
+ pvals = _check_py_seq(pvals)
678
+ jit_error_if(jnp.sum(pvals[:-1]) > 1., self._check_p2, pvals)
679
+ if isinstance(n, jax.core.Tracer):
680
+ raise ValueError("The total count parameter `n` should not be a jax abstract array.")
681
+ size = _size2shape(size)
682
+ n_max = int(np.max(jax.device_get(n)))
683
+ batch_shape = lax.broadcast_shapes(jnp.shape(pvals)[:-1], jnp.shape(n))
684
+ r = _multinomial(key, pvals, n, n_max, batch_shape + size)
685
+ dtype = dtype or environ.ditype()
686
+ return jnp.asarray(r, dtype=dtype)
687
+
688
+ def multivariate_normal(
689
+ self,
690
+ mean,
691
+ cov,
860
692
  size: Optional[Size] = None,
693
+ method: str = 'cholesky',
861
694
  key: Optional[SeedOrKey] = None,
862
- dtype: DTypeLike = None):
863
- dtype = dtype or environ.dftype()
864
- df = jnp.asarray(_check_py_seq(df), dtype=dtype)
865
- if size is None:
866
- size = np.shape(df)
867
- else:
868
- size = _size2shape(size)
869
- _check_shape("t", size, np.shape(df))
870
- if key is None:
871
- keys = self.split_keys(2)
872
- else:
873
- keys = jr.split(_formalize_key(key), 2)
874
- n = jr.normal(keys[0], size, dtype=dtype)
875
- two = _const(n, 2)
876
- half_df = lax.div(df, two)
877
- g = jr.gamma(keys[1], half_df, size, dtype=dtype)
878
- r = n * jnp.sqrt(half_df / g)
879
- return r
880
-
881
- def orthogonal(self,
882
- n: int,
695
+ dtype: DTypeLike = None
696
+ ):
697
+ if method not in {'svd', 'eigh', 'cholesky'}:
698
+ raise ValueError("method must be one of {'svd', 'eigh', 'cholesky'}")
699
+ dtype = dtype or environ.dftype()
700
+ mean = u.math.asarray(_check_py_seq(mean), dtype=dtype)
701
+ cov = u.math.asarray(_check_py_seq(cov), dtype=dtype)
702
+ if isinstance(mean, u.Quantity):
703
+ assert isinstance(cov, u.Quantity)
704
+ assert mean.unit ** 2 == cov.unit
705
+ mean = mean.mantissa if isinstance(mean, u.Quantity) else mean
706
+ cov = cov.mantissa if isinstance(cov, u.Quantity) else cov
707
+ unit = mean.unit if isinstance(mean, u.Quantity) else u.Unit()
708
+
709
+ key = self.split_key() if key is None else _formalize_key(key)
710
+ if not jnp.ndim(mean) >= 1:
711
+ raise ValueError(f"multivariate_normal requires mean.ndim >= 1, got mean.ndim == {jnp.ndim(mean)}")
712
+ if not jnp.ndim(cov) >= 2:
713
+ raise ValueError(f"multivariate_normal requires cov.ndim >= 2, got cov.ndim == {jnp.ndim(cov)}")
714
+ n = mean.shape[-1]
715
+ if jnp.shape(cov)[-2:] != (n, n):
716
+ raise ValueError(f"multivariate_normal requires cov.shape == (..., n, n) for n={n}, "
717
+ f"but got cov.shape == {jnp.shape(cov)}.")
718
+ if size is None:
719
+ size = lax.broadcast_shapes(mean.shape[:-1], cov.shape[:-2])
720
+ else:
721
+ size = _size2shape(size)
722
+ _check_shape("normal", size, mean.shape[:-1], cov.shape[:-2])
723
+
724
+ if method == 'svd':
725
+ (u_, s, _) = jnp.linalg.svd(cov)
726
+ factor = u_ * jnp.sqrt(s[..., None, :])
727
+ elif method == 'eigh':
728
+ (w, v) = jnp.linalg.eigh(cov)
729
+ factor = v * jnp.sqrt(w[..., None, :])
730
+ else: # 'cholesky'
731
+ factor = jnp.linalg.cholesky(cov)
732
+ normal_samples = jr.normal(key, size + mean.shape[-1:], dtype=dtype)
733
+ r = mean + jnp.einsum('...ij,...j->...i', factor, normal_samples)
734
+ return r if unit.is_unitless else u.Quantity(r, unit=unit)
735
+
736
def rayleigh(self,
             scale=1.0,
             size: Optional[Size] = None,
             key: Optional[SeedOrKey] = None,
             dtype: DTypeLike = None):
    """Draw samples computed as ``sqrt(-2 * log(U)) * scale`` with ``U`` uniform.

    Args:
        scale: multiplier applied to the transformed uniform samples.
        size: output shape; when ``None`` the shape of ``scale`` is used.
        key: optional seed or PRNG key; a fresh key is split from this
            generator when omitted.
        dtype: dtype of the uniform draw; defaults to ``environ.dftype()``.

    Returns:
        The scaled samples.
    """
    scale = _check_py_seq(scale)
    out_size = jnp.shape(scale) if size is None else size
    key = self.split_key() if key is None else _formalize_key(key)
    dtype = dtype or environ.dftype()
    uniform_samples = uniform_for_unit(key, shape=_size2shape(out_size), minval=0, maxval=1, dtype=dtype)
    radius = jnp.sqrt(-2. * jnp.log(uniform_samples))
    return radius * scale
749
+
750
def triangular(self,
               size: Optional[Size] = None,
               key: Optional[SeedOrKey] = None):
    """Return samples that are ``-1`` or ``+1``, each with probability 0.5.

    Args:
        size: output shape (``None`` gives a scalar shape ``()``).
        key: optional seed or PRNG key; a fresh key is split from this
            generator when omitted.

    Returns:
        ``2 * b - 1`` where ``b`` is a Bernoulli(0.5) draw.
    """
    key = self.split_key() if key is None else _formalize_key(key)
    coin_flips = jr.bernoulli(key, p=0.5, shape=_size2shape(size))
    return 2 * coin_flips - 1
757
+
758
def vonmises(self,
             mu,
             kappa,
             size: Optional[Size] = None,
             key: Optional[SeedOrKey] = None,
             dtype: DTypeLike = None):
    """Draw von Mises samples centred at ``mu`` and wrapped into ``[-pi, pi)``.

    Args:
        mu: location (mode) of the distribution.
        kappa: concentration passed to ``_von_mises_centered``.
        size: output shape; defaults to the broadcast shape of ``mu`` and ``kappa``.
        key: optional seed or PRNG key; a fresh key is split from this
            generator when omitted.
        dtype: floating dtype; defaults to ``environ.dftype()``.

    Returns:
        The wrapped angular samples.
    """
    key = self.split_key() if key is None else _formalize_key(key)
    dtype = dtype or environ.dftype()
    mu = jnp.asarray(_check_py_seq(mu), dtype=dtype)
    kappa = jnp.asarray(_check_py_seq(kappa), dtype=dtype)
    if size is None:
        size = lax.broadcast_shapes(jnp.shape(mu), jnp.shape(kappa))
    out_shape = _size2shape(size)
    centered = _von_mises_centered(key, kappa, out_shape, dtype=dtype)
    shifted = centered + mu
    # Wrap the shifted angles back into the principal interval.
    return (shifted + jnp.pi) % (2.0 * jnp.pi) - jnp.pi
775
+
776
def weibull(self,
            a,
            size: Optional[Size] = None,
            key: Optional[SeedOrKey] = None,
            dtype: DTypeLike = None):
    """Draw Weibull samples with shape parameter ``a`` by inverse transform.

    Args:
        a: shape parameter; must be scalar whenever ``size`` is given.
        size: output shape; defaults to the shape of ``a``.
        key: optional seed or PRNG key; a fresh key is split from this
            generator when omitted.
        dtype: dtype of the uniform draw; defaults to ``environ.dftype()``.

    Returns:
        ``(-log1p(-U)) ** (1 / a)`` for uniform ``U``.

    Raises:
        ValueError: if ``size`` is given and ``a`` has more than one element.
    """
    key = self.split_key() if key is None else _formalize_key(key)
    a = _check_py_seq(a)
    if size is None:
        size = jnp.shape(a)
    elif jnp.size(a) > 1:
        raise ValueError(f'"a" should be a scalar when "size" is provided. But we got {a}')
    out_shape = _size2shape(size)
    dtype = dtype or environ.dftype()
    uniform_samples = uniform_for_unit(key=key, shape=out_shape, minval=0, maxval=1, dtype=dtype)
    return jnp.power(-jnp.log1p(-uniform_samples), 1.0 / a)
793
+
794
def weibull_min(self,
                a,
                scale=None,
                size: Optional[Size] = None,
                key: Optional[SeedOrKey] = None,
                dtype: DTypeLike = None):
    """Draw samples from a Weibull-minimum distribution.

    Args:
        a: concentration (shape) parameter; must be scalar whenever
            ``size`` is given.
        scale: optional scale parameter. ``None`` means unit scale.
        size: output shape; defaults to the broadcast shape of ``a`` and
            ``scale`` (only ``a`` when ``scale`` is ``None``).
        key: optional seed or PRNG key; a fresh key is split from this
            generator when omitted.
        dtype: dtype of the uniform draw; defaults to ``environ.dftype()``.

    Returns:
        ``scale * (-log1p(-U)) ** (1 / a)`` for uniform ``U``.

    Raises:
        ValueError: if ``size`` is given and ``a`` has more than one element.
    """
    key = self.split_key() if key is None else _formalize_key(key)
    a = _check_py_seq(a)
    if scale is not None:
        scale = _check_py_seq(scale)
    if size is None:
        # Only consult ``scale`` for the shape when it was actually supplied;
        # asking for the shape of ``None`` is not valid.
        param_shapes = [jnp.shape(a)]
        if scale is not None:
            param_shapes.append(jnp.shape(scale))
        size = jnp.broadcast_shapes(*param_shapes)
    elif jnp.size(a) > 1:
        raise ValueError(f'"a" should be a scalar when "size" is provided. But we got {a}')
    size = _size2shape(size)
    dtype = dtype or environ.dftype()
    random_uniform = uniform_for_unit(key=key, shape=size, minval=0, maxval=1, dtype=dtype)
    r = jnp.power(-jnp.log1p(-random_uniform), 1.0 / a)
    if scale is not None:
        # A scale parameter stretches the standard variate, so it multiplies
        # (the previous code divided, which inverted the parameter's meaning).
        r = r * scale
    return r
815
+
816
def maxwell(self,
            size: Optional[Size] = None,
            key: Optional[SeedOrKey] = None,
            dtype: DTypeLike = None):
    """Draw samples as the Euclidean norm of three standard normal variates.

    Args:
        size: output shape (``None`` gives a scalar).
        key: optional seed or PRNG key; a fresh key is split from this
            generator when omitted.
        dtype: floating dtype; defaults to ``environ.dftype()``.

    Returns:
        Norms taken over a trailing axis of length 3.
    """
    key = self.split_key() if key is None else _formalize_key(key)
    # One trailing axis of length 3 holds the normal components per sample.
    component_shape = _size2shape(size) + (3,)
    dtype = dtype or environ.dftype()
    components = jr.normal(key=key, shape=component_shape, dtype=dtype)
    return jnp.linalg.norm(components, axis=-1)
826
+
827
def negative_binomial(self,
                      n,
                      p,
                      size: Optional[Size] = None,
                      key: Optional[SeedOrKey] = None,
                      dtype: DTypeLike = None):
    """Draw negative-binomial samples as a gamma-Poisson mixture.

    A gamma variate with shape ``n`` and scale ``(1 - p) / p`` is drawn and
    then used as the rate of a Poisson draw.

    Args:
        n: shape parameter of the gamma mixing distribution.
        p: probability parameter.
        size: output shape; defaults to the broadcast shape of ``n`` and ``p``.
        key: optional seed or PRNG key; it is split in two for the gamma and
            Poisson draws. Two keys are split from this generator when omitted.
        dtype: integer dtype of the result; defaults to ``environ.ditype()``.

    Returns:
        The Poisson samples.
    """
    n = _check_py_seq(n)
    p = _check_py_seq(p)
    if size is None:
        size = lax.broadcast_shapes(jnp.shape(n), jnp.shape(p))
    out_shape = _size2shape(size)
    log_odds = jnp.log(p) - jnp.log1p(-p)
    keys = self.split_key(2) if key is None else jr.split(_formalize_key(key), 2)
    gamma_key, poisson_key = keys[0], keys[1]
    rate = self.gamma(shape=n, scale=jnp.exp(-log_odds), size=out_shape, key=gamma_key, dtype=environ.dftype())
    return self.poisson(lam=rate, key=poisson_key, dtype=dtype or environ.ditype())
846
+
847
def wald(self,
         mean,
         scale,
         size: Optional[Size] = None,
         key: Optional[SeedOrKey] = None,
         dtype: DTypeLike = None):
    """Draw samples from a Wald (inverse Gaussian) distribution.

    Args:
        mean: distribution mean (``loc`` in the derivation below).
        scale: scale parameter (``conc`` in the derivation below).
        size: output shape; defaults to the broadcast shape of ``mean`` and
            ``scale``.
        key: optional seed or PRNG key. When given, every random draw in this
            method is derived from it, so the result is reproducible. Two keys
            are split from this generator when omitted.
        dtype: floating dtype; defaults to ``environ.dftype()``.

    Returns:
        The sampled values.
    """
    dtype = dtype or environ.dftype()
    # Two independent streams are needed: one for the squared normal and one
    # for the uniform used in the final selection step. Both must come from
    # the caller's key (previously the normal draw ignored ``key``).
    if key is None:
        keys = self.split_key(2)
    else:
        keys = jr.split(_formalize_key(key), 2)
    mean = jnp.asarray(_check_py_seq(mean), dtype=dtype)
    scale = jnp.asarray(_check_py_seq(scale), dtype=dtype)
    if size is None:
        size = lax.broadcast_shapes(jnp.shape(mean), jnp.shape(scale))
    size = _size2shape(size)
    sampled_chi2 = jnp.square(jr.normal(keys[0], size, dtype=dtype))
    sampled_uniform = self.uniform(size=size, key=keys[1], dtype=dtype)
    # Wikipedia defines an intermediate x with the formula
    #   x = loc + loc ** 2 * y / (2 * conc) - loc / (2 * conc) * sqrt(4 * loc * conc * y + loc ** 2 * y ** 2)
    # where y ~ N(0, 1)**2 (sampled_chi2 above) and conc is the concentration.
    # Let us write
    #   w = loc * y / (2 * conc)
    # Then we can extract the common factor in the last two terms to obtain
    #   x = loc + loc * w * (1 - sqrt(2 / w + 1))
    # Now we see that the Wikipedia formula suffers from catastrophic
    # cancellation for large w (e.g., if conc << loc).
    #
    # Fortunately, we can fix this by multiplying both sides
    # by 1 + sqrt(2 / w + 1).  We get
    #   x * (1 + sqrt(2 / w + 1)) =
    #     = loc * (1 + sqrt(2 / w + 1)) + loc * w * (1 - (2 / w + 1))
    #     = loc * (sqrt(2 / w + 1) - 1)
    # The term sqrt(2 / w + 1) + 1 no longer presents numerical
    # difficulties for large w, and sqrt(2 / w + 1) - 1 is just
    # sqrt1pm1(2 / w), which we know how to compute accurately.
    # This just leaves the matter of small w, where 2 / w may
    # overflow.  In the limit as w -> 0, x -> loc, so we just mask
    # that case.
    sqrt1pm1_arg = 4 * scale / (mean * sampled_chi2)  # 2 / w above
    safe_sqrt1pm1_arg = jnp.where(sqrt1pm1_arg < np.inf, sqrt1pm1_arg, 1.0)
    denominator = 1.0 + jnp.sqrt(safe_sqrt1pm1_arg + 1.0)
    ratio = jnp.expm1(0.5 * jnp.log1p(safe_sqrt1pm1_arg)) / denominator
    sampled = mean * jnp.where(sqrt1pm1_arg < np.inf, ratio, 1.0)  # x above
    res = jnp.where(sampled_uniform <= mean / (mean + sampled),
                    sampled,
                    jnp.square(mean) / sampled)
    return res
892
+
893
def t(self,
      df,
      size: Optional[Size] = None,
      key: Optional[SeedOrKey] = None,
      dtype: DTypeLike = None):
    """Draw Student-t samples as ``Z * sqrt((df / 2) / G)``.

    ``Z`` is standard normal and ``G`` is a gamma variate with shape ``df / 2``.

    Args:
        df: degrees of freedom.
        size: output shape; defaults to the shape of ``df``. When given, it is
            validated against the shape of ``df`` via ``_check_shape``.
        key: optional seed or PRNG key; it is split in two. Two keys are split
            from this generator when omitted.
        dtype: floating dtype; defaults to ``environ.dftype()``.

    Returns:
        The sampled values.
    """
    dtype = dtype or environ.dftype()
    df = jnp.asarray(_check_py_seq(df), dtype=dtype)
    if size is None:
        size = np.shape(df)
    else:
        size = _size2shape(size)
        _check_shape("t", size, np.shape(df))
    keys = self.split_key(2) if key is None else jr.split(_formalize_key(key), 2)
    normal_key, gamma_key = keys[0], keys[1]
    z = jr.normal(normal_key, size, dtype=dtype)
    half_df = lax.div(df, _const(z, 2))
    g = jr.gamma(gamma_key, half_df, size, dtype=dtype)
    return z * jnp.sqrt(half_df / g)
915
+
916
def orthogonal(self,
               n: int,
               size: Optional[Size] = None,
               key: Optional[SeedOrKey] = None,
               dtype: DTypeLike = None):
    """Sample ``n x n`` matrices from the QR factorisation of Gaussian matrices.

    Args:
        n: matrix dimension; must be a concrete integer.
        size: leading batch shape of the result.
        key: optional seed or PRNG key; a fresh key is split from this
            generator when omitted.
        dtype: floating dtype; defaults to ``environ.dftype()``.

    Returns:
        An array of shape ``size + (n, n)``.
    """
    dtype = dtype or environ.dftype()
    key = self.split_key() if key is None else _formalize_key(key)
    batch_shape = _size2shape(size)
    _check_shape("orthogonal", batch_shape)
    n = core.concrete_or_error(index, n, "The error occurred in jax.random.orthogonal()")
    gaussian = jr.normal(key, batch_shape + (n, n), dtype=dtype)
    q, upper = jnp.linalg.qr(gaussian)
    diag = jnp.diagonal(upper, 0, -2, -1)
    # Rescale each column of q by the sign of the matching diagonal entry of R.
    return q * jnp.expand_dims(diag / abs(diag), -2)
931
+
932
def noncentral_chisquare(self,
                         df,
                         nonc,
                         size: Optional[Size] = None,
                         key: Optional[SeedOrKey] = None,
                         dtype: DTypeLike = None):
    """Draw samples from a noncentral chi-square distribution.

    Args:
        df: degrees of freedom.
        nonc: non-centrality parameter.
        size: output shape; defaults to the broadcast shape of ``df`` and ``nonc``.
        key: optional seed or PRNG key; it is split in three. Three keys are
            split from this generator when omitted.
        dtype: floating dtype; defaults to ``environ.dftype()``.

    Returns:
        The sampled values.
    """
    dtype = dtype or environ.dftype()
    df = jnp.asarray(_check_py_seq(df), dtype=dtype)
    nonc = jnp.asarray(_check_py_seq(nonc), dtype=dtype)
    if size is None:
        size = lax.broadcast_shapes(jnp.shape(df), jnp.shape(nonc))
    out_shape = _size2shape(size)
    keys = self.split_key(3) if key is None else jr.split(_formalize_key(key), 3)
    poisson_draw = jr.poisson(keys[0], 0.5 * nonc, shape=out_shape, dtype=environ.ditype())
    shifted_normal = jr.normal(keys[1], shape=out_shape, dtype=dtype) + jnp.sqrt(nonc)
    # Two constructions are blended element-wise depending on whether df > 1.
    df_gt_one = jnp.greater(df, 1.0)
    gamma_df = jnp.where(df_gt_one, df - 1.0, df + 2.0 * poisson_draw)
    chi2 = 2.0 * jr.gamma(keys[2], 0.5 * gamma_df, shape=out_shape, dtype=dtype)
    return jnp.where(df_gt_one, chi2 + shifted_normal * shifted_normal, chi2)
955
+
956
def loggamma(self,
             a,
             size: Optional[Size] = None,
             key: Optional[SeedOrKey] = None,
             dtype: DTypeLike = None):
    """Draw log-gamma samples by delegating to ``jax.random.loggamma``.

    Args:
        a: shape parameter of the underlying gamma distribution.
        size: output shape; defaults to the shape of ``a``.
        key: optional seed or PRNG key; a fresh key is split from this
            generator when omitted.
        dtype: floating dtype; defaults to ``environ.dftype()``.

    Returns:
        The sampled values.
    """
    dtype = dtype or environ.dftype()
    key = self.split_key() if key is None else _formalize_key(key)
    a = _check_py_seq(a)
    out_size = jnp.shape(a) if size is None else size
    return jr.loggamma(key, a, shape=_size2shape(out_size), dtype=dtype)
968
+
969
def categorical(self,
                logits,
                axis: int = -1,
                size: Optional[Size] = None,
                key: Optional[SeedOrKey] = None):
    """Sample category indices from ``logits`` via ``jax.random.categorical``.

    Args:
        logits: unnormalised log-probabilities.
        axis: axis of ``logits`` that enumerates the categories.
        size: output shape; defaults to the shape of ``logits`` with ``axis``
            removed.
        key: optional seed or PRNG key; a fresh key is split from this
            generator when omitted.

    Returns:
        The sampled indices.
    """
    key = self.split_key() if key is None else _formalize_key(key)
    logits = _check_py_seq(logits)
    if size is None:
        # Drop the category axis to obtain the batch shape.
        batch_dims = list(jnp.shape(logits))
        batch_dims.pop(axis)
        size = batch_dims
    return jr.categorical(key, logits, axis=axis, shape=_size2shape(size))
981
+
982
def zipf(self,
         a,
         size: Optional[Size] = None,
         key: Optional[SeedOrKey] = None,
         dtype: DTypeLike = None):
    """Draw Zipf samples by calling ``numpy.random.zipf`` through a host callback.

    Args:
        a: distribution parameter forwarded to ``numpy.random.zipf``.
        size: output shape (int, tuple or list); defaults to the shape of ``a``.
        key: accepted for signature parity with the other samplers; this
            implementation does not use it.
        dtype: integer dtype of the result; defaults to ``environ.ditype()``.

    Returns:
        The sampled values.
    """
    a = _check_py_seq(a)
    if size is None:
        size = jnp.shape(a)
    # Normalise to a tuple, as the sibling callback-based samplers do, so an
    # int ``size`` is a valid shape for ``jax.ShapeDtypeStruct``.
    size = _size2shape(size)
    dtype = dtype or environ.ditype()
    r = jax.pure_callback(lambda x: np.random.zipf(x, size).astype(dtype),
                          jax.ShapeDtypeStruct(size, dtype),
                          a)
    return r
995
+
996
def power(self,
          a,
          size: Optional[Size] = None,
          key: Optional[SeedOrKey] = None,
          dtype: DTypeLike = None):
    """Draw power-distribution samples via ``numpy.random.power`` in a host callback.

    Args:
        a: distribution parameter forwarded to ``numpy.random.power``.
        size: output shape; defaults to the shape of ``a``.
        key: accepted for signature parity with the other samplers; this
            implementation does not use it.
        dtype: floating dtype of the result; defaults to ``environ.dftype()``.

    Returns:
        The sampled values.
    """
    a = _check_py_seq(a)
    out_shape = _size2shape(jnp.shape(a) if size is None else size)
    dtype = dtype or environ.dftype()

    def _host_sample(exponent):
        return np.random.power(a=exponent, size=out_shape).astype(dtype)

    return jax.pure_callback(_host_sample, jax.ShapeDtypeStruct(out_shape, dtype), a)
1010
+
1011
def f(self,
      dfnum,
      dfden,
      size: Optional[Size] = None,
      key: Optional[SeedOrKey] = None,
      dtype: DTypeLike = None):
    """Draw F-distribution samples via ``numpy.random.f`` in a host callback.

    Args:
        dfnum: numerator degrees of freedom.
        dfden: denominator degrees of freedom.
        size: output shape; defaults to the broadcast shape of ``dfnum`` and
            ``dfden``.
        key: accepted for signature parity with the other samplers; this
            implementation does not use it.
        dtype: floating dtype of the result; defaults to ``environ.dftype()``.

    Returns:
        The sampled values.
    """
    dfnum = _check_py_seq(dfnum)
    dfden = _check_py_seq(dfden)
    if size is None:
        size = jnp.broadcast_shapes(jnp.shape(dfnum), jnp.shape(dfden))
    size = _size2shape(size)
    dtype = dtype or environ.dftype()
    # The parameters are passed to the callback positionally; the dict that
    # used to be built here was never read and has been dropped.
    r = jax.pure_callback(lambda dfnum_, dfden_: np.random.f(dfnum=dfnum_,
                                                             dfden=dfden_,
                                                             size=size).astype(dtype),
                          jax.ShapeDtypeStruct(size, dtype),
                          dfnum, dfden)
    return r
1030
+
1031
def hypergeometric(
    self,
    ngood,
    nbad,
    nsample,
    size: Optional[Size] = None,
    key: Optional[SeedOrKey] = None,
    dtype: DTypeLike = None
):
    """Draw hypergeometric samples via ``numpy.random.hypergeometric`` in a host callback.

    Args:
        ngood: forwarded as ``ngood`` to NumPy.
        nbad: forwarded as ``nbad`` to NumPy.
        nsample: forwarded as ``nsample`` to NumPy.
        size: output shape; defaults to the broadcast shape of the three
            parameters.
        key: accepted for signature parity with the other samplers; this
            implementation does not use it.
        dtype: integer dtype of the result; defaults to ``environ.ditype()``.

    Returns:
        The sampled values.
    """
    ngood = _check_py_seq(ngood)
    nbad = _check_py_seq(nbad)
    nsample = _check_py_seq(nsample)

    if size is None:
        size = lax.broadcast_shapes(jnp.shape(ngood),
                                    jnp.shape(nbad),
                                    jnp.shape(nsample))
    out_shape = _size2shape(size)
    dtype = dtype or environ.ditype()
    # The three parameters travel to the host together as one dict argument.
    params = {'ngood': ngood, 'nbad': nbad, 'nsample': nsample}

    def _host_sample(args):
        drawn = np.random.hypergeometric(ngood=args['ngood'],
                                         nbad=args['nbad'],
                                         nsample=args['nsample'],
                                         size=out_shape)
        return drawn.astype(dtype)

    return jax.pure_callback(_host_sample, jax.ShapeDtypeStruct(out_shape, dtype), params)
1058
+
1059
def logseries(self,
              p,
              size: Optional[Size] = None,
              key: Optional[SeedOrKey] = None,
              dtype: DTypeLike = None):
    """Draw log-series samples via ``numpy.random.logseries`` in a host callback.

    Args:
        p: distribution parameter forwarded to ``numpy.random.logseries``.
        size: output shape; defaults to the shape of ``p``.
        key: accepted for signature parity with the other samplers; this
            implementation does not use it.
        dtype: integer dtype of the result; defaults to ``environ.ditype()``.

    Returns:
        The sampled values.
    """
    p = _check_py_seq(p)
    out_shape = _size2shape(jnp.shape(p) if size is None else size)
    dtype = dtype or environ.ditype()

    def _host_sample(prob):
        return np.random.logseries(p=prob, size=out_shape).astype(dtype)

    return jax.pure_callback(_host_sample, jax.ShapeDtypeStruct(out_shape, dtype), p)
1073
+
1074
def noncentral_f(self,
                 dfnum,
                 dfden,
                 nonc,
                 size: Optional[Size] = None,
                 key: Optional[SeedOrKey] = None,
                 dtype: DTypeLike = None):
    """Draw noncentral-F samples via ``numpy.random.noncentral_f`` in a host callback.

    Args:
        dfnum: numerator degrees of freedom.
        dfden: denominator degrees of freedom.
        nonc: non-centrality parameter.
        size: output shape; defaults to the broadcast shape of the three
            parameters.
        key: accepted for signature parity with the other samplers; this
            implementation does not use it.
        dtype: floating dtype of the result; defaults to ``environ.dftype()``.

    Returns:
        The sampled values.
    """
    dfnum = _check_py_seq(dfnum)
    dfden = _check_py_seq(dfden)
    nonc = _check_py_seq(nonc)
    if size is None:
        size = lax.broadcast_shapes(jnp.shape(dfnum),
                                    jnp.shape(dfden),
                                    jnp.shape(nonc))
    out_shape = _size2shape(size)
    # The three parameters travel to the host together as one dict argument.
    params = {'dfnum': dfnum, 'dfden': dfden, 'nonc': nonc}
    dtype = dtype or environ.dftype()

    def _host_sample(args):
        drawn = np.random.noncentral_f(dfnum=args['dfnum'],
                                       dfden=args['dfden'],
                                       nonc=args['nonc'],
                                       size=out_shape)
        return drawn.astype(dtype)

    return jax.pure_callback(_host_sample, jax.ShapeDtypeStruct(out_shape, dtype), params)
1098
+
1099
+ # PyTorch compatibility #
1100
+ # --------------------- #
1101
+
1102
def rand_like(self, input, *, dtype=None, key: Optional[SeedOrKey] = None):
    """Return uniform random numbers with the same shape as ``input``.

    The values come from ``self.random`` and are then cast with ``astype``.

    Args:
      input: only its shape is used; it fixes the shape of the output.
      dtype: forwarded to ``astype`` on the generated samples.
        NOTE(review): the dtype of ``input`` is never read here, so ``None``
        is resolved by ``astype`` itself - confirm the intended default.
      key: the seed or key for the random.

    Returns:
      The random data.
    """
    target_shape = jnp.shape(input)
    samples = self.random(target_shape, key=key)
    return samples.astype(dtype)
1115
+
1116
def randn_like(self, input, *, dtype=None, key: Optional[SeedOrKey] = None):
    """Return normally distributed random numbers with the same shape as ``input``.

    The values come from ``self.randn`` and are then cast with ``astype``.

    Args:
      input: only its shape is used; it fixes the shape of the output.
      dtype: forwarded to ``astype`` on the generated samples.
        NOTE(review): the dtype of ``input`` is never read here, so ``None``
        is resolved by ``astype`` itself - confirm the intended default.
      key: the seed or key for the random.

    Returns:
      The random data.
    """
    target_shape = jnp.shape(input)
    samples = self.randn(*target_shape, key=key)
    return samples.astype(dtype)
1129
+
1130
def randint_like(self, input, low=0, high=None, *, dtype=None, key: Optional[SeedOrKey] = None):
    """Return random integers with the same shape as ``input``.

    Args:
      input: its shape fixes the shape of the output; when ``high`` is
        omitted, its largest element is also used as ``high``.
      low: lower bound forwarded to ``self.randint``.
      high: upper bound forwarded to ``self.randint``; defaults to the
        maximum element of ``input``.
      dtype: forwarded to ``self.randint``.
      key: the seed or key for the random.

    Returns:
      The random data.
    """
    if high is None:
        # ``jnp.max`` reduces over every axis; the builtin ``max`` only works
        # for one-dimensional inputs and fails on higher ranks.
        high = jnp.max(input)
    return self.randint(low, high=high, size=jnp.shape(input), dtype=dtype, key=key)
1099
1134
 
1100
1135
 
1101
1136
  # default random generator
@@ -1106,393 +1141,393 @@ DEFAULT = RandomState(np.random.randint(0, 10000, size=2, dtype=np.uint32))
1106
1141
 
1107
1142
 
1108
1143
def _formalize_key(key):
    """Convert a seed or raw key into a ``uint32`` key array.

    Args:
        key: a Python ``int`` seed, or a ``jax.Array``/``numpy.ndarray`` of
            dtype ``uint32`` holding exactly two elements.

    Returns:
        ``jr.PRNGKey(key)`` for an int seed, otherwise ``key`` as a
        ``uint32`` JAX array.

    Raises:
        TypeError: for any other type, dtype or element count.
    """
    if isinstance(key, int):
        return jr.PRNGKey(key)
    if not isinstance(key, (jax.Array, np.ndarray)):
        raise TypeError('key must be a int or an array with two uint32.')
    if key.dtype != jnp.uint32:
        raise TypeError('key must be a int or an array with two uint32.')
    if key.size != 2:
        raise TypeError('key must be a int or an array with two uint32.')
    return jnp.asarray(key, dtype=jnp.uint32)
1119
1154
 
1120
1155
 
1121
1156
def _size2shape(size):
    """Normalise a ``size`` argument into a shape tuple.

    ``None`` becomes ``()``, a tuple or list becomes a tuple, and anything
    else is wrapped in a one-element tuple.
    """
    if size is None:
        return ()
    if isinstance(size, (tuple, list)):
        return tuple(size)
    return (size,)
1128
1163
 
1129
1164
 
1130
1165
def _check_shape(name, shape, *param_shapes):
    """Verify that broadcasting ``shape`` with the parameter shapes leaves it unchanged.

    Args:
        name: sampler name used in the error message.
        shape: requested output shape.
        *param_shapes: shapes of the distribution parameters; when none are
            given the function does nothing.

    Raises:
        ValueError: if the broadcast result differs from ``shape``.
    """
    if not param_shapes:
        return
    broadcast = lax.broadcast_shapes(shape, *param_shapes)
    if shape != broadcast:
        template = ("{} parameter shapes must be broadcast-compatible with shape "
                    "argument, and the result of broadcasting the shapes must equal "
                    "the shape argument, but got result {} for shape argument {}.")
        raise ValueError(template.format(name, broadcast, shape))
1138
1173
 
1139
1174
 
1140
1175
def _is_python_scalar(x):
    """Report whether ``x`` should be treated as a Python scalar.

    Objects exposing ``aval`` answer with ``aval.weak_type``; otherwise the
    result is ``True`` for zero-dimensional values and for instances of
    ``bool``, ``int``, ``float`` or ``complex``.
    """
    if hasattr(x, 'aval'):
        return x.aval.weak_type
    return np.ndim(x) == 0 or isinstance(x, (bool, int, float, complex))
1149
1184
 
1150
1185
 
1151
1186
# Maps each builtin Python scalar type to the NumPy dtype used to represent it.
python_scalar_dtypes = {
    scalar_type: np.dtype(dtype_name)
    for scalar_type, dtype_name in (
        (bool, 'bool'),
        (int, 'int64'),
        (float, 'float64'),
        (complex, 'complex128'),
    )
}
1157
1192
 
1158
1193
 
1159
1194
def _dtype(x, *, canonicalize: bool = False):
    """Return the dtype for a value or type.

    Args:
        x: a Python scalar type, a Python scalar value, an object with a
            ``dtype`` attribute, or anything ``np.result_type`` accepts.
        canonicalize: when true, the result is passed through
            ``dtypes.canonicalize_dtype``.

    Raises:
        ValueError: if ``x`` is ``None``.
    """
    if x is None:
        raise ValueError(f"Invalid argument to dtype: {x}.")
    # A scalar *type* is looked up directly, a scalar *value* by its type.
    scalar_kind = x if isinstance(x, type) else type(x)
    if scalar_kind in python_scalar_dtypes:
        dt = python_scalar_dtypes[scalar_kind]
    elif hasattr(x, 'dtype'):
        dt = x.dtype
    else:
        dt = np.result_type(x)
    if canonicalize:
        return dtypes.canonicalize_dtype(dt)
    return dt
1172
1207
 
1173
1208
 
1174
1209
def _const(example, val):
    """Build the constant ``val`` with a dtype matching ``example``.

    For array-like examples the result is a NumPy array of the example's
    canonicalised dtype. For Python-scalar examples ``val`` is converted to
    the example's scalar type and returned as-is when its canonical dtype
    already matches, otherwise wrapped in a NumPy array.
    """
    if not _is_python_scalar(example):
        return np.array(val, dtypes.canonicalize_dtype(example.dtype))
    dtype = dtypes.canonicalize_dtype(type(example))
    val = dtypes.scalar_type_of(example)(val)
    if dtype == _dtype(val, canonicalize=True):
        return val
    return np.array(val, dtype)
1182
1217
 
1183
1218
 
1184
1219
# Record of the precomputed quantities returned by ``_get_tr_params``.
_tr_params = namedtuple(
    "tr_params",
    "c b a alpha u_r v_r m log_p log1_p log_h",
)
1187
1222
 
1188
1223
 
1189
1224
def _get_tr_params(n, p):
    """Precompute the constants used by the BTRS binomial sampler.

    See Table 1 of the reference cited in ``_binomial_btrs``. In addition to
    those constants, ``log(p)``, ``log1p(-p)`` and the terms of ``log(f(k))``
    that depend only on ``(n, p, m)`` are computed once here.

    Args:
        n: number of trials; its ``dtype`` is used for the mode ``m``.
        p: success probability.

    Returns:
        A ``_tr_params`` record.
    """
    mean = n * p
    spread = jnp.sqrt(mean * (1 - p))
    b = 1.15 + 2.53 * spread
    m = jnp.floor((n + 1) * p).astype(n.dtype)
    log_p = jnp.log(p)
    log1_p = jnp.log1p(-p)
    log_h = ((m + 0.5) * (jnp.log((m + 1.0) / (n - m + 1.0)) + log1_p - log_p) +
             _stirling_approx_tail(m) + _stirling_approx_tail(n - m))
    return _tr_params(
        c=mean + 0.5,
        b=b,
        a=-0.0873 + 0.0248 * b + 0.01 * p,
        alpha=(2.83 + 5.1 / b) * spread,
        u_r=0.43,
        v_r=0.92 - 4.2 / b,
        m=m,
        log_p=log_p,
        log1_p=log1_p,
        log_h=log_h,
    )
1206
1241
 
1207
1242
 
1208
1243
def _stirling_approx_tail(k):
    """Evaluate the correction term used in the BTRS log-density.

    For ``k < 10`` the value is read from a ten-entry table; otherwise it is
    computed from a short series in ``1 / (k + 1)``. Both candidates are
    evaluated and selected element-wise with ``jnp.where``.
    """
    table = jnp.array([0.08106146679532726,
                       0.04134069595540929,
                       0.02767792568499834,
                       0.02079067210376509,
                       0.01664469118982119,
                       0.01387612882307075,
                       0.01189670994589177,
                       0.01041126526197209,
                       0.009255462182712733,
                       0.008330563433362871],
                      dtype=environ.dftype())
    kp1 = k + 1
    kp1_squared = (k + 1) ** 2
    series = (1.0 / 12 - (1.0 / 360 - (1.0 / 1260) / kp1_squared) / kp1_squared) / kp1
    return jnp.where(k < 10, table[k], series)
1225
1260
 
1226
1261
 
1227
1262
def _binomial_btrs(key, p, n):
    """Sample one binomial variate with transformed rejection sampling (BTRS).

    Based on the algorithm from the following reference:

    Hormann, "The Generation of Binomial Random Variates"
    (https://core.ac.uk/download/pdf/11007254.pdf)

    Args:
        key: PRNG key driving the rejection loop.
        p: success probability.
        n: number of trials.

    Returns:
        The accepted candidate ``k``.
    """
    tr_params = _get_tr_params(n, p)

    def _propose(state):
        # Draw a fresh (u, v) pair and map u to a candidate k.
        _, rng, _, _ = state
        rng, rng_u, rng_v = jr.split(rng, 3)
        u = jr.uniform(rng_u)
        v = jr.uniform(rng_v)
        u = u - 0.5
        k = jnp.floor(
            (2 * tr_params.a / (0.5 - jnp.abs(u)) + tr_params.b) * u + tr_params.c
        ).astype(n.dtype)
        return k, rng, u, v

    def _accepts(k, u, v):
        # See acceptance condition in Step 3. (Page 3) of TRS algorithm
        # v <= f(k) * g_grad(u) / alpha
        m = tr_params.m
        log_p = tr_params.log_p
        log1_p = tr_params.log1_p
        # See: formula for log(f(k)) at bottom of Page 5.
        log_f = (
            (n + 1.0) * jnp.log((n - m + 1.0) / (n - k + 1.0))
            + (k + 0.5) * (jnp.log((n - k + 1.0) / (k + 1.0)) + log_p - log1_p)
            + (_stirling_approx_tail(k) - _stirling_approx_tail(n - k))
            + tr_params.log_h
        )
        g = (tr_params.a / (0.5 - jnp.abs(u)) ** 2) + tr_params.b
        return jnp.log((v * tr_params.alpha) / g) <= log_f

    def _keep_looping(state):
        k, _, u, v = state
        early_accept = (jnp.abs(u) <= tr_params.u_r) & (v <= tr_params.v_r)
        early_reject = (k < 0) | (k > n)
        # Only fall through to the full acceptance test when neither
        # shortcut applies.
        return lax.cond(
            early_accept | early_reject,
            (),
            lambda _: ~early_accept,
            (k, u, v),
            lambda x: ~_accepts(*x),
        )

    # use k=-1 initially so that the loop condition returns True
    final_state = lax.while_loop(_keep_looping, _propose, (-1, key, 1.0, 1.0))
    return final_state[0]
1281
1316
 
1282
1317
 
1283
1318
  def _binomial_inversion(key, p, n):
1284
- def _binom_inv_body_fn(val):
1285
- i, key, geom_acc = val
1286
- key, key_u = jr.split(key)
1287
- u = jr.uniform(key_u)
1288
- geom = jnp.floor(jnp.log1p(-u) / log1_p) + 1
1289
- geom_acc = geom_acc + geom
1290
- return i + 1, key, geom_acc
1319
+ def _binom_inv_body_fn(val):
1320
+ i, key, geom_acc = val
1321
+ key, key_u = jr.split(key)
1322
+ u = jr.uniform(key_u)
1323
+ geom = jnp.floor(jnp.log1p(-u) / log1_p) + 1
1324
+ geom_acc = geom_acc + geom
1325
+ return i + 1, key, geom_acc
1291
1326
 
1292
- def _binom_inv_cond_fn(val):
1293
- i, _, geom_acc = val
1294
- return geom_acc <= n
1327
+ def _binom_inv_cond_fn(val):
1328
+ i, _, geom_acc = val
1329
+ return geom_acc <= n
1295
1330
 
1296
- log1_p = jnp.log1p(-p)
1297
- ret = lax.while_loop(_binom_inv_cond_fn, _binom_inv_body_fn, (-1, key, 0.0))
1298
- return ret[0]
1331
+ log1_p = jnp.log1p(-p)
1332
+ ret = lax.while_loop(_binom_inv_cond_fn, _binom_inv_body_fn, (-1, key, 0.0))
1333
+ return ret[0]
1299
1334
 
1300
1335
 
1301
1336
  def _binomial_dispatch(key, p, n):
1302
- def dispatch(key, p, n):
1303
- is_le_mid = p <= 0.5
1304
- pq = jnp.where(is_le_mid, p, 1 - p)
1305
- mu = n * pq
1306
- k = lax.cond(
1307
- mu < 10,
1308
- (key, pq, n),
1309
- lambda x: _binomial_inversion(*x),
1310
- (key, pq, n),
1311
- lambda x: _binomial_btrs(*x),
1337
+ def dispatch(key, p, n):
1338
+ is_le_mid = p <= 0.5
1339
+ pq = jnp.where(is_le_mid, p, 1 - p)
1340
+ mu = n * pq
1341
+ k = lax.cond(
1342
+ mu < 10,
1343
+ (key, pq, n),
1344
+ lambda x: _binomial_inversion(*x),
1345
+ (key, pq, n),
1346
+ lambda x: _binomial_btrs(*x),
1347
+ )
1348
+ return jnp.where(is_le_mid, k, n - k)
1349
+
1350
+ # Return 0 for nan `p` or negative `n`, since nan values are not allowed for integer types
1351
+ cond0 = jnp.isfinite(p) & (n > 0) & (p > 0)
1352
+ return lax.cond(
1353
+ cond0 & (p < 1),
1354
+ (key, p, n),
1355
+ lambda x: dispatch(*x),
1356
+ (),
1357
+ lambda _: jnp.where(cond0, n, 0),
1312
1358
  )
1313
- return jnp.where(is_le_mid, k, n - k)
1314
-
1315
- # Return 0 for nan `p` or negative `n`, since nan values are not allowed for integer types
1316
- cond0 = jnp.isfinite(p) & (n > 0) & (p > 0)
1317
- return lax.cond(
1318
- cond0 & (p < 1),
1319
- (key, p, n),
1320
- lambda x: dispatch(*x),
1321
- (),
1322
- lambda _: jnp.where(cond0, n, 0),
1323
- )
1324
1359
 
1325
1360
 
1326
1361
  @partial(jit, static_argnums=(3,))
1327
1362
  def _binomial(key, p, n, shape):
1328
- shape = shape or lax.broadcast_shapes(jnp.shape(p), jnp.shape(n))
1329
- # reshape to map over axis 0
1330
- p = jnp.reshape(jnp.broadcast_to(p, shape), -1)
1331
- n = jnp.reshape(jnp.broadcast_to(n, shape), -1)
1332
- key = jr.split(key, jnp.size(p))
1333
- if jax.default_backend() == "cpu":
1334
- ret = lax.map(lambda x: _binomial_dispatch(*x), (key, p, n))
1335
- else:
1336
- ret = vmap(lambda *x: _binomial_dispatch(*x))(key, p, n)
1337
- return jnp.reshape(ret, shape)
1363
+ shape = shape or lax.broadcast_shapes(jnp.shape(p), jnp.shape(n))
1364
+ # reshape to map over axis 0
1365
+ p = jnp.reshape(jnp.broadcast_to(p, shape), -1)
1366
+ n = jnp.reshape(jnp.broadcast_to(n, shape), -1)
1367
+ key = jr.split(key, jnp.size(p))
1368
+ if jax.default_backend() == "cpu":
1369
+ ret = lax.map(lambda x: _binomial_dispatch(*x), (key, p, n))
1370
+ else:
1371
+ ret = vmap(lambda *x: _binomial_dispatch(*x))(key, p, n)
1372
+ return jnp.reshape(ret, shape)
1338
1373
 
1339
1374
 
1340
1375
  @partial(jit, static_argnums=(2,))
1341
1376
  def _categorical(key, p, shape):
1342
- # this implementation is fast when event shape is small, and slow otherwise
1343
- # Ref: https://stackoverflow.com/a/34190035
1344
- shape = shape or p.shape[:-1]
1345
- s = jnp.cumsum(p, axis=-1)
1346
- r = jr.uniform(key, shape=shape + (1,))
1347
- return jnp.sum(s < r, axis=-1)
1377
+ # this implementation is fast when event shape is small, and slow otherwise
1378
+ # Ref: https://stackoverflow.com/a/34190035
1379
+ shape = shape or p.shape[:-1]
1380
+ s = jnp.cumsum(p, axis=-1)
1381
+ r = jr.uniform(key, shape=shape + (1,))
1382
+ return jnp.sum(s < r, axis=-1)
1348
1383
 
1349
1384
 
1350
1385
  def _scatter_add_one(operand, indices, updates):
1351
- return lax.scatter_add(
1352
- operand,
1353
- indices,
1354
- updates,
1355
- lax.ScatterDimensionNumbers(
1356
- update_window_dims=(),
1357
- inserted_window_dims=(0,),
1358
- scatter_dims_to_operand_dims=(0,),
1359
- ),
1360
- )
1386
+ return lax.scatter_add(
1387
+ operand,
1388
+ indices,
1389
+ updates,
1390
+ lax.ScatterDimensionNumbers(
1391
+ update_window_dims=(),
1392
+ inserted_window_dims=(0,),
1393
+ scatter_dims_to_operand_dims=(0,),
1394
+ ),
1395
+ )
1361
1396
 
1362
1397
 
1363
1398
  def _reshape(x, shape):
1364
- if isinstance(x, (int, float, np.ndarray, np.generic)):
1365
- return np.reshape(x, shape)
1366
- else:
1367
- return jnp.reshape(x, shape)
1399
+ if isinstance(x, (int, float, np.ndarray, np.generic)):
1400
+ return np.reshape(x, shape)
1401
+ else:
1402
+ return jnp.reshape(x, shape)
1368
1403
 
1369
1404
 
1370
1405
  def _promote_shapes(*args, shape=()):
1371
- # adapted from lax.lax_numpy
1372
- if len(args) < 2 and not shape:
1373
- return args
1374
- else:
1375
- shapes = [jnp.shape(arg) for arg in args]
1376
- num_dims = len(lax.broadcast_shapes(shape, *shapes))
1377
- return [
1378
- _reshape(arg, (1,) * (num_dims - len(s)) + s) if len(s) < num_dims else arg
1379
- for arg, s in zip(args, shapes)
1380
- ]
1406
+ # adapted from lax.lax_numpy
1407
+ if len(args) < 2 and not shape:
1408
+ return args
1409
+ else:
1410
+ shapes = [jnp.shape(arg) for arg in args]
1411
+ num_dims = len(lax.broadcast_shapes(shape, *shapes))
1412
+ return [
1413
+ _reshape(arg, (1,) * (num_dims - len(s)) + s) if len(s) < num_dims else arg
1414
+ for arg, s in zip(args, shapes)
1415
+ ]
1381
1416
 
1382
1417
 
1383
1418
  @partial(jit, static_argnums=(3, 4))
1384
1419
  def _multinomial(key, p, n, n_max, shape=()):
1385
- if jnp.shape(n) != jnp.shape(p)[:-1]:
1386
- broadcast_shape = lax.broadcast_shapes(jnp.shape(n), jnp.shape(p)[:-1])
1387
- n = jnp.broadcast_to(n, broadcast_shape)
1388
- p = jnp.broadcast_to(p, broadcast_shape + jnp.shape(p)[-1:])
1389
- shape = shape or p.shape[:-1]
1390
- if n_max == 0:
1391
- return jnp.zeros(shape + p.shape[-1:], dtype=jnp.result_type(int))
1392
- # get indices from categorical distribution then gather the result
1393
- indices = _categorical(key, p, (n_max,) + shape)
1394
- # mask out values when counts is heterogeneous
1395
- if jnp.ndim(n) > 0:
1396
- mask = _promote_shapes(jnp.arange(n_max) < jnp.expand_dims(n, -1), shape=shape + (n_max,))[0]
1397
- mask = jnp.moveaxis(mask, -1, 0).astype(indices.dtype)
1398
- excess = jnp.concatenate([jnp.expand_dims(n_max - n, -1),
1399
- jnp.zeros(jnp.shape(n) + (p.shape[-1] - 1,))],
1400
- -1)
1401
- else:
1402
- mask = 1
1403
- excess = 0
1404
- # NB: we transpose to move batch shape to the front
1405
- indices_2D = (jnp.reshape(indices * mask, (n_max, -1))).T
1406
- samples_2D = vmap(_scatter_add_one)(jnp.zeros((indices_2D.shape[0], p.shape[-1]), dtype=indices.dtype),
1407
- jnp.expand_dims(indices_2D, axis=-1),
1408
- jnp.ones(indices_2D.shape, dtype=indices.dtype))
1409
- return jnp.reshape(samples_2D, shape + p.shape[-1:]) - excess
1420
+ if jnp.shape(n) != jnp.shape(p)[:-1]:
1421
+ broadcast_shape = lax.broadcast_shapes(jnp.shape(n), jnp.shape(p)[:-1])
1422
+ n = jnp.broadcast_to(n, broadcast_shape)
1423
+ p = jnp.broadcast_to(p, broadcast_shape + jnp.shape(p)[-1:])
1424
+ shape = shape or p.shape[:-1]
1425
+ if n_max == 0:
1426
+ return jnp.zeros(shape + p.shape[-1:], dtype=jnp.result_type(int))
1427
+ # get indices from categorical distribution then gather the result
1428
+ indices = _categorical(key, p, (n_max,) + shape)
1429
+ # mask out values when counts is heterogeneous
1430
+ if jnp.ndim(n) > 0:
1431
+ mask = _promote_shapes(jnp.arange(n_max) < jnp.expand_dims(n, -1), shape=shape + (n_max,))[0]
1432
+ mask = jnp.moveaxis(mask, -1, 0).astype(indices.dtype)
1433
+ excess = jnp.concatenate([jnp.expand_dims(n_max - n, -1),
1434
+ jnp.zeros(jnp.shape(n) + (p.shape[-1] - 1,))],
1435
+ -1)
1436
+ else:
1437
+ mask = 1
1438
+ excess = 0
1439
+ # NB: we transpose to move batch shape to the front
1440
+ indices_2D = (jnp.reshape(indices * mask, (n_max, -1))).T
1441
+ samples_2D = vmap(_scatter_add_one)(jnp.zeros((indices_2D.shape[0], p.shape[-1]), dtype=indices.dtype),
1442
+ jnp.expand_dims(indices_2D, axis=-1),
1443
+ jnp.ones(indices_2D.shape, dtype=indices.dtype))
1444
+ return jnp.reshape(samples_2D, shape + p.shape[-1:]) - excess
1410
1445
 
1411
1446
 
1412
1447
  @partial(jit, static_argnums=(2, 3), static_argnames=['shape', 'dtype'])
1413
1448
  def _von_mises_centered(key, concentration, shape, dtype=None):
1414
- """Compute centered von Mises samples using rejection sampling from [1]_ with wrapped Cauchy proposal.
1415
-
1416
- Returns
1417
- -------
1418
- out: array_like
1419
- centered samples from von Mises
1420
-
1421
- References
1422
- ----------
1423
- .. [1] Luc Devroye "Non-Uniform Random Variate Generation", Springer-Verlag, 1986;
1424
- Chapter 9, p. 473-476. http://www.nrbook.com/devroye/Devroye_files/chapter_nine.pdf
1425
-
1426
- """
1427
- shape = shape or jnp.shape(concentration)
1428
- dtype = dtype or environ.dftype()
1429
- concentration = lax.convert_element_type(concentration, dtype)
1430
- concentration = jnp.broadcast_to(concentration, shape)
1431
-
1432
- if dtype == jnp.float16:
1433
- s_cutoff = 1.8e-1
1434
- elif dtype == jnp.float32:
1435
- s_cutoff = 2e-2
1436
- elif dtype == jnp.float64:
1437
- s_cutoff = 1.2e-4
1438
- else:
1439
- raise ValueError(f"Unsupported dtype: {dtype}")
1440
-
1441
- r = 1.0 + jnp.sqrt(1.0 + 4.0 * concentration ** 2)
1442
- rho = (r - jnp.sqrt(2.0 * r)) / (2.0 * concentration)
1443
- s_exact = (1.0 + rho ** 2) / (2.0 * rho)
1444
-
1445
- s_approximate = 1.0 / concentration
1446
-
1447
- s = jnp.where(concentration > s_cutoff, s_exact, s_approximate)
1448
-
1449
- def cond_fn(*args):
1450
- """check if all are done or reached max number of iterations"""
1451
- i, _, done, _, _ = args[0]
1452
- return jnp.bitwise_and(i < 100, jnp.logical_not(jnp.all(done)))
1453
-
1454
- def body_fn(*args):
1455
- i, key, done, _, w = args[0]
1456
- uni_ukey, uni_vkey, key = jr.split(key, 3)
1457
- u = jr.uniform(
1458
- key=uni_ukey,
1459
- shape=shape,
1460
- dtype=concentration.dtype,
1461
- minval=-1.0,
1462
- maxval=1.0,
1463
- )
1464
- z = jnp.cos(jnp.pi * u)
1465
- w = jnp.where(done, w, (1.0 + s * z) / (s + z)) # Update where not done
1466
- y = concentration * (s - w)
1467
- v = jr.uniform(key=uni_vkey, shape=shape, dtype=concentration.dtype)
1468
- accept = (y * (2.0 - y) >= v) | (jnp.log(y / v) + 1.0 >= y)
1469
- return i + 1, key, accept | done, u, w
1449
+ """Compute centered von Mises samples using rejection sampling from [1]_ with wrapped Cauchy proposal.
1470
1450
 
1471
- init_done = jnp.zeros(shape, dtype=bool)
1472
- init_u = jnp.zeros(shape)
1473
- init_w = jnp.zeros(shape)
1451
+ Returns
1452
+ -------
1453
+ out: array_like
1454
+ centered samples from von Mises
1474
1455
 
1475
- _, _, done, u, w = lax.while_loop(
1476
- cond_fun=cond_fn,
1477
- body_fun=body_fn,
1478
- init_val=(jnp.array(0), key, init_done, init_u, init_w),
1479
- )
1456
+ References
1457
+ ----------
1458
+ .. [1] Luc Devroye "Non-Uniform Random Variate Generation", Springer-Verlag, 1986;
1459
+ Chapter 9, p. 473-476. http://www.nrbook.com/devroye/Devroye_files/chapter_nine.pdf
1480
1460
 
1481
- return jnp.sign(u) * jnp.arccos(w)
1461
+ """
1462
+ shape = shape or jnp.shape(concentration)
1463
+ dtype = dtype or environ.dftype()
1464
+ concentration = lax.convert_element_type(concentration, dtype)
1465
+ concentration = jnp.broadcast_to(concentration, shape)
1466
+
1467
+ if dtype == jnp.float16:
1468
+ s_cutoff = 1.8e-1
1469
+ elif dtype == jnp.float32:
1470
+ s_cutoff = 2e-2
1471
+ elif dtype == jnp.float64:
1472
+ s_cutoff = 1.2e-4
1473
+ else:
1474
+ raise ValueError(f"Unsupported dtype: {dtype}")
1475
+
1476
+ r = 1.0 + jnp.sqrt(1.0 + 4.0 * concentration ** 2)
1477
+ rho = (r - jnp.sqrt(2.0 * r)) / (2.0 * concentration)
1478
+ s_exact = (1.0 + rho ** 2) / (2.0 * rho)
1479
+
1480
+ s_approximate = 1.0 / concentration
1481
+
1482
+ s = jnp.where(concentration > s_cutoff, s_exact, s_approximate)
1483
+
1484
+ def cond_fn(*args):
1485
+ """check if all are done or reached max number of iterations"""
1486
+ i, _, done, _, _ = args[0]
1487
+ return jnp.bitwise_and(i < 100, jnp.logical_not(jnp.all(done)))
1488
+
1489
+ def body_fn(*args):
1490
+ i, key, done, _, w = args[0]
1491
+ uni_ukey, uni_vkey, key = jr.split(key, 3)
1492
+ u = jr.uniform(
1493
+ key=uni_ukey,
1494
+ shape=shape,
1495
+ dtype=concentration.dtype,
1496
+ minval=-1.0,
1497
+ maxval=1.0,
1498
+ )
1499
+ z = jnp.cos(jnp.pi * u)
1500
+ w = jnp.where(done, w, (1.0 + s * z) / (s + z)) # Update where not done
1501
+ y = concentration * (s - w)
1502
+ v = jr.uniform(key=uni_vkey, shape=shape, dtype=concentration.dtype)
1503
+ accept = (y * (2.0 - y) >= v) | (jnp.log(y / v) + 1.0 >= y)
1504
+ return i + 1, key, accept | done, u, w
1505
+
1506
+ init_done = jnp.zeros(shape, dtype=bool)
1507
+ init_u = jnp.zeros(shape)
1508
+ init_w = jnp.zeros(shape)
1509
+
1510
+ _, _, done, u, w = lax.while_loop(
1511
+ cond_fun=cond_fn,
1512
+ body_fun=body_fn,
1513
+ init_val=(jnp.array(0), key, init_done, init_u, init_w),
1514
+ )
1515
+
1516
+ return jnp.sign(u) * jnp.arccos(w)
1482
1517
 
1483
1518
 
1484
1519
  def _loc_scale(loc, scale, value):
1485
- if loc is None:
1486
- if scale is None:
1487
- return value
1488
- else:
1489
- return value * scale
1490
- else:
1491
- if scale is None:
1492
- return value + loc
1520
+ if loc is None:
1521
+ if scale is None:
1522
+ return value
1523
+ else:
1524
+ return value * scale
1493
1525
  else:
1494
- return value * scale + loc
1526
+ if scale is None:
1527
+ return value + loc
1528
+ else:
1529
+ return value * scale + loc
1495
1530
 
1496
1531
 
1497
1532
  def _check_py_seq(seq):
1498
- return jnp.asarray(seq) if isinstance(seq, (tuple, list)) else seq
1533
+ return jnp.asarray(seq) if isinstance(seq, (tuple, list)) else seq