orng 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orng/__init__.py +23 -0
- orng/_utils.py +29 -0
- orng/backends/__init__.py +101 -0
- orng/backends/cupy.py +244 -0
- orng/backends/jax.py +256 -0
- orng/backends/numpy.py +209 -0
- orng/backends/torch.py +380 -0
- orng/functional.py +222 -0
- orng/orng.py +222 -0
- orng-0.1.0.dist-info/METADATA +203 -0
- orng-0.1.0.dist-info/RECORD +14 -0
- orng-0.1.0.dist-info/WHEEL +5 -0
- orng-0.1.0.dist-info/licenses/LICENSE +21 -0
- orng-0.1.0.dist-info/top_level.txt +1 -0
orng/__init__.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
import importlib.metadata
|
|
2
|
+
|
|
3
|
+
try:
    # Prefer the version recorded in the installed distribution's metadata.
    __version__: str = importlib.metadata.version("orng")
except importlib.metadata.PackageNotFoundError:
    # No installed distribution metadata is available (e.g. the package is
    # imported from a bare source tree), so fall back to a placeholder.
    __version__ = "unknown"
|
|
7
|
+
|
|
8
|
+
from .backends import create_backend_from_xp, infer_backend_name_from_xp
|
|
9
|
+
from .functional import (
|
|
10
|
+
create_functional_backend,
|
|
11
|
+
create_functional_backend_from_xp,
|
|
12
|
+
)
|
|
13
|
+
from .orng import ArrayRNG, RandomGenerator
|
|
14
|
+
|
|
15
|
+
__all__ = [
    # Stateful generator front-ends.
    "ArrayRNG",
    "RandomGenerator",
    # Backend construction and inference helpers.
    "create_backend_from_xp",
    "create_functional_backend",
    "create_functional_backend_from_xp",
    "infer_backend_name_from_xp",
    # Package metadata.
    "__version__",
]
|
orng/_utils.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import math
|
|
4
|
+
from typing import Sequence, Tuple
|
|
5
|
+
|
|
6
|
+
# Types accepted wherever a ``size`` argument is taken: ``None``, a single
# ``int``, or a sequence of ints. ``normalize_shape`` maps ``None`` to ``()``
# and an ``int`` n to ``(n,)``.
# NOTE: this alias is evaluated at import time (it is not an annotation, so
# ``from __future__ import annotations`` does not defer it); the ``X | Y``
# union syntax used here therefore needs Python 3.10 or newer.
SizeLike = int | Sequence[int] | None
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def normalize_shape(size: SizeLike) -> Tuple[int, ...]:
    """Convert a size argument to a canonical ``tuple`` form.

    Args:
        size: ``None`` for a scalar draw; a single integer (a Python ``int``
            or any integer-like object implementing ``__index__``, such as a
            NumPy integer scalar or a 0-d integer array); or an iterable of
            dimensions, each converted with ``int()``.

    Returns:
        The shape as a tuple of ints; ``()`` when *size* is ``None``.

    Raises:
        ValueError: If the size, or any entry of it, is negative.
    """
    import operator

    if size is None:
        return ()

    # Integer-like scalars from array libraries are not ``int`` instances and
    # are not iterable, so detect them through the ``__index__`` protocol.
    try:
        scalar = operator.index(size)
    except TypeError:
        scalar = None

    if scalar is not None:
        if scalar < 0:
            raise ValueError("size must be non-negative.")
        return (scalar,)

    shape = tuple(int(dim) for dim in size)
    if any(dim < 0 for dim in shape):
        raise ValueError("size entries must be non-negative.")
    return shape
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def total_size(shape: Tuple[int, ...]) -> int:
    """Return how many elements an array of the given *shape* holds.

    An empty shape describes a scalar, which holds exactly one element.
    """
    return math.prod(shape) if shape else 1
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
__all__ = [
    "SizeLike",
    "normalize_shape",
    "total_size",
]
|
|
orng/backends/__init__.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any, Callable, Dict
|
|
4
|
+
|
|
5
|
+
from .cupy import CuPyBackend
|
|
6
|
+
from .jax import JAXBackend
|
|
7
|
+
from .numpy import NumPyBackend
|
|
8
|
+
from .torch import TorchBackend
|
|
9
|
+
|
|
10
|
+
BackendFactory = Callable[..., Any]
|
|
11
|
+
|
|
12
|
+
_FACTORIES: Dict[str, BackendFactory] = {
|
|
13
|
+
"numpy": lambda *, seed=None, generator=None, device=None: NumPyBackend(
|
|
14
|
+
seed=seed,
|
|
15
|
+
generator=generator,
|
|
16
|
+
),
|
|
17
|
+
"cupy": lambda *, seed=None, generator=None, device=None: CuPyBackend(
|
|
18
|
+
seed=seed,
|
|
19
|
+
generator=generator,
|
|
20
|
+
),
|
|
21
|
+
"torch": lambda *, seed=None, generator=None, device=None: TorchBackend(
|
|
22
|
+
seed=seed,
|
|
23
|
+
generator=generator,
|
|
24
|
+
device=device,
|
|
25
|
+
),
|
|
26
|
+
"pytorch": lambda *, seed=None, generator=None, device=None: TorchBackend(
|
|
27
|
+
seed=seed,
|
|
28
|
+
generator=generator,
|
|
29
|
+
device=device,
|
|
30
|
+
),
|
|
31
|
+
"jax": lambda *, seed=None, generator=None, device=None: JAXBackend(
|
|
32
|
+
seed=seed,
|
|
33
|
+
key=generator,
|
|
34
|
+
),
|
|
35
|
+
}
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def infer_backend_name_from_xp(xp: Any) -> str:
    """Map an array namespace to the name of the matching ORNG backend.

    Args:
        xp: An array namespace object, as recognised by ``array_api_compat``.

    Returns:
        One of ``"numpy"``, ``"jax"``, ``"torch"`` or ``"cupy"``.

    Raises:
        ImportError: If ``array_api_compat`` (or one of the predicates used
            here) cannot be imported.
        ValueError: If *xp* matches none of the supported namespaces.
    """
    try:
        from array_api_compat import (
            is_cupy_namespace,
            is_jax_namespace,
            is_numpy_namespace,
            is_torch_namespace,
        )
    except ImportError as exc:  # pragma: no cover - optional dependency
        raise ImportError(
            "Inferring ORNG backends from an array namespace requires "
            "'array_api_compat'."
        ) from exc

    # Probed in order; the first predicate that accepts ``xp`` wins.
    probes = (
        ("numpy", is_numpy_namespace),
        ("jax", is_jax_namespace),
        ("torch", is_torch_namespace),
        ("cupy", is_cupy_namespace),
    )
    for backend_name, matches in probes:
        if matches(xp):
            return backend_name
    raise ValueError("Unsupported array namespace for ORNG backend inference.")
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def create_backend(
    name: str,
    *,
    seed: int | None,
    generator: Any | None,
    device: Any | None,
):
    """Instantiate the stateful backend registered under *name*.

    The lookup is case-insensitive. *seed*, *generator* and *device* are
    handed to the backend's factory unchanged.

    Raises:
        ValueError: If no backend is registered under *name*.
    """
    lookup_key = name.lower()
    try:
        factory = _FACTORIES[lookup_key]
    except KeyError as exc:  # pragma: no cover - defensive
        # "pytorch" is only an alias, so it is left out of the listing.
        canonical_names = sorted({k for k in _FACTORIES if k != "pytorch"})
        supported = "', '".join(canonical_names)
        raise ValueError(
            f"Unsupported backend '{name}'. Expected one of '{supported}'."
        ) from exc
    return factory(seed=seed, generator=generator, device=device)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def create_backend_from_xp(
    xp: Any,
    *,
    seed: int | None,
    generator: Any | None,
    device: Any | None,
):
    """Instantiate the stateful backend that matches array namespace *xp*.

    Shorthand for ``create_backend(infer_backend_name_from_xp(xp), ...)``.
    """
    backend_name = infer_backend_name_from_xp(xp)
    return create_backend(
        backend_name, seed=seed, generator=generator, device=device
    )
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
__all__ = [
    # Construction by backend name.
    "create_backend",
    # Construction from / inference of an array namespace.
    "create_backend_from_xp",
    "infer_backend_name_from_xp",
]
|
orng/backends/cupy.py
ADDED
|
@@ -0,0 +1,244 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import copy
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from .._utils import SizeLike
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class CuPyBackend:
    """Stateful sampler backed by :class:`CuPyFunctionalBackend`.

    The functional backend is created with ``pure=False``, so ``self._state``
    holds the live ``cupy.random.Generator``. Every public method threads that
    state through the corresponding functional call and keeps the state it
    gets back.
    """

    def __init__(self, *, seed: int | None, generator: Any | None) -> None:
        self._impl = CuPyFunctionalBackend(pure=False)
        self._state = self._impl.init_state(seed=seed, generator=generator)

    def _advance(self, method: str, *args: Any, **kwargs: Any) -> Any:
        # Run ``self._impl.<method>(state, ...)``, store the returned state for
        # the next call and hand back only the sampled values.
        sampler = getattr(self._impl, method)
        values, self._state = sampler(self._state, *args, **kwargs)
        return values

    def random(self, *, size: SizeLike, dtype: Any | None) -> Any:
        """Draw floats via :meth:`CuPyFunctionalBackend.random`."""
        return self._advance("random", size=size, dtype=dtype)

    def uniform(
        self,
        *,
        low: Any,
        high: Any,
        size: SizeLike,
        dtype: Any | None,
    ) -> Any:
        """Draw via :meth:`CuPyFunctionalBackend.uniform`."""
        return self._advance(
            "uniform", low=low, high=high, size=size, dtype=dtype
        )

    def normal(
        self,
        *,
        loc: Any,
        scale: Any,
        size: SizeLike,
        dtype: Any | None,
    ) -> Any:
        """Draw via :meth:`CuPyFunctionalBackend.normal`."""
        return self._advance(
            "normal", loc=loc, scale=scale, size=size, dtype=dtype
        )

    def gamma(
        self,
        *,
        shape: Any,
        scale: Any,
        size: SizeLike,
        dtype: Any | None,
    ) -> Any:
        """Draw via :meth:`CuPyFunctionalBackend.gamma`."""
        return self._advance(
            "gamma", shape=shape, scale=scale, size=size, dtype=dtype
        )

    def choice(
        self,
        population: int | Any,
        *,
        size: SizeLike,
        replace: bool,
        probabilities: Any | None,
    ) -> Any:
        """Sample via :meth:`CuPyFunctionalBackend.choice`."""
        return self._advance(
            "choice",
            population,
            size=size,
            replace=replace,
            probabilities=probabilities,
        )
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
class CuPyFunctionalBackend:
    """Functional (state-in, state-out) sampler built on ``cupy.random``.

    Every sampling method takes the current *state* first and returns a
    ``(result, next_state)`` pair.

    * ``pure=True`` (the default): a state is a deep copy of
      ``Generator.bit_generator.state``. Each call rebuilds a generator from
      a copy of it, so the caller's state object is not mutated.
    * ``pure=False``: a state is the ``cupy.random.Generator`` itself; it is
      advanced in place and returned as the next state.
    """

    def __init__(self, *, pure: bool = True) -> None:
        try:
            import cupy as cp
        except ImportError as exc:  # pragma: no cover - optional dependency
            raise ImportError(
                "CuPy backend requires the 'cupy' package to be installed. "
                "Install it with `pip install orng[cupy]`."
            ) from exc
        self._cupy = cp
        self._pure = pure

    def init_state(self, *, seed: int | None, generator: Any | None) -> Any:
        """Build the initial state from *seed* or an existing *generator*.

        Raises:
            TypeError: If *generator* is given but is not a
                ``cupy.random.Generator``.
        """
        cp = self._cupy
        if generator is None:
            gen = cp.random.default_rng(seed)
        elif isinstance(generator, cp.random.Generator):
            gen = generator
        else:
            raise TypeError(
                "generator must be a cupy.random.Generator when using the "
                "CuPy backend."
            )
        if self._pure:
            return copy.deepcopy(gen.bit_generator.state)
        return gen

    def _generator_from_state(self, state: Any) -> Any:
        """Return a generator positioned at *state* (see class docstring).

        Raises:
            TypeError: In non-pure mode, if *state* is not a generator.
        """
        if not self._pure:
            if not isinstance(state, self._cupy.random.Generator):
                raise TypeError(
                    "state must be a cupy.random.Generator when pure=False."
                )
            return state
        gen = self._cupy.random.default_rng()
        # Assign a copy so the new generator never shares the caller's state.
        gen.bit_generator.state = copy.deepcopy(state)
        return gen

    def _result_and_next_state(self, gen: Any, result: Any) -> tuple[Any, Any]:
        """Pair *result* with the state to pass to the next call."""
        if self._pure:
            return result, copy.deepcopy(gen.bit_generator.state)
        return result, gen

    def random(
        self,
        state: Any,
        *,
        size: SizeLike,
        dtype: Any | None = None,
    ) -> tuple[Any, Any]:
        """Draw floats with ``Generator.random``."""
        gen = self._generator_from_state(state)
        result = gen.random(size=size, dtype=dtype)
        return self._result_and_next_state(gen, result)

    def uniform(
        self,
        state: Any,
        *,
        low: Any,
        high: Any,
        size: SizeLike,
        dtype: Any | None = None,
    ) -> tuple[Any, Any]:
        """Draw with ``Generator.uniform(low, high)``."""
        gen = self._generator_from_state(state)
        result = gen.uniform(low=low, high=high, size=size, dtype=dtype)
        return self._result_and_next_state(gen, result)

    def normal(
        self,
        state: Any,
        *,
        loc: Any,
        scale: Any,
        size: SizeLike,
        dtype: Any | None = None,
    ) -> tuple[Any, Any]:
        """Draw ``loc + scale * z`` with ``z`` from ``standard_normal``."""
        gen = self._generator_from_state(state)
        standard = gen.standard_normal(size=size, dtype=dtype)
        result = loc + standard * scale
        return self._result_and_next_state(gen, result)

    def gamma(
        self,
        state: Any,
        *,
        shape: Any,
        scale: Any,
        size: SizeLike,
        dtype: Any | None = None,
    ) -> tuple[Any, Any]:
        """Draw with ``Generator.gamma``, casting to *dtype* when given."""
        gen = self._generator_from_state(state)
        result = gen.gamma(shape=shape, scale=scale, size=size)
        if dtype is not None:
            result = self._cupy.asarray(result, dtype=dtype)
        return self._result_and_next_state(gen, result)

    def choice(
        self,
        state: Any,
        population: int | Any,
        *,
        size: SizeLike,
        replace: bool,
        probabilities: Any | None,
    ) -> tuple[Any, Any]:
        """Sample from ``range(population)`` or the leading axis of *population*.

        Args:
            state: Current state (see class docstring).
            population: An ``int`` ``n`` (sample indices ``0 .. n-1``) or an
                array-like whose first axis is sampled.
            size: Output shape; ``None`` returns a single element.
            replace: Whether to sample with replacement.
            probabilities: Optional weights, one per candidate. They are
                normalised to sum to one; ``None`` means uniform.

        Returns:
            ``(samples, next_state)``.

        Raises:
            ValueError: If *probabilities* is not 1-D with one entry per
                candidate, or if more samples than candidates are requested
                with ``replace=False``.
        """
        cp = self._cupy
        gen = self._generator_from_state(state)

        if isinstance(population, int):
            n = population
            values = None
        else:
            values = cp.asarray(population)
            n = values.shape[0]

        target_size = size if size is not None else ()

        probs = None
        if probabilities is not None:
            probs = cp.asarray(probabilities, dtype=cp.float64)
            # A length mismatch would otherwise yield silently wrong or
            # out-of-range indices.
            if probs.ndim != 1 or probs.shape[0] != n:
                raise ValueError(
                    "probabilities must be a 1-D array with one entry per "
                    "population element."
                )
            probs = probs / cp.sum(probs)

        if replace:
            if probs is None:
                indices = gen.integers(0, n, size=target_size)
            else:
                cdf = cp.cumsum(probs)
                # Pin the last CDF entry to exactly 1.0 so a draw in [0, 1)
                # can never land past the end because of rounding.
                cdf = cdf / cdf[-1]
                draws = gen.random(size=target_size)
                # side="right" sends a draw equal to cdf[i] to bucket i + 1,
                # so zero-probability entries can never be chosen (this is
                # what NumPy's Generator.choice does as well).
                indices = cp.searchsorted(cdf, draws, side="right")
        else:
            flat_k = (
                1
                if target_size == ()
                else int(cp.prod(cp.asarray(target_size)))
            )
            if flat_k > n:
                raise ValueError(
                    "Cannot take a larger sample than population when "
                    "replace=False."
                )
            if probs is None:
                indices = gen.permutation(n)[:flat_k]
            elif flat_k == 0:
                # Nothing to draw; also avoids argpartition(keys, n) below.
                indices = cp.zeros(0, dtype=cp.int64)
            else:
                # Gumbel-top-k: perturb log-weights with Gumbel noise and
                # take the k largest keys.
                gumbels = -cp.log(-cp.log(gen.random(n)))
                keys = cp.log(probs) + gumbels
                top = cp.argpartition(keys, n - flat_k)[n - flat_k:]
                # argpartition leaves the top-k in arbitrary order; ordering by
                # descending key reproduces sequential weighted sampling
                # without replacement (largest key = first draw).
                indices = top[cp.argsort(-keys[top])]
            if target_size != ():
                indices = cp.reshape(indices, target_size)
            else:
                indices = indices[0]

        if values is None:
            result = indices
        else:
            result = values[indices]
        return self._result_and_next_state(gen, result)
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
__all__ = [
    "CuPyBackend",
    "CuPyFunctionalBackend",
]
|
orng/backends/jax.py
ADDED
|
@@ -0,0 +1,256 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import secrets
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from .._utils import SizeLike, normalize_shape
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class JAXBackend:
    """Stateful sampler backed by :class:`JAXFunctionalBackend`.

    ``self._state`` holds the current PRNG key. Every public method passes it
    to the functional implementation and replaces it with the key that comes
    back.
    """

    def __init__(self, *, seed: int | None, key: Any | None) -> None:
        self._impl = JAXFunctionalBackend()
        self._state = self._impl.init_state(seed=seed, generator=key)

    def _advance(self, method: str, *args: Any, **kwargs: Any) -> Any:
        # Run ``self._impl.<method>(key, ...)``, store the successor key and
        # hand back only the sampled values.
        sampler = getattr(self._impl, method)
        values, self._state = sampler(self._state, *args, **kwargs)
        return values

    def random(self, *, size: SizeLike, dtype: Any | None) -> Any:
        """Draw floats via :meth:`JAXFunctionalBackend.random`."""
        return self._advance("random", size=size, dtype=dtype)

    def uniform(
        self,
        *,
        low: Any,
        high: Any,
        size: SizeLike,
        dtype: Any | None,
    ) -> Any:
        """Draw via :meth:`JAXFunctionalBackend.uniform`."""
        return self._advance(
            "uniform", low=low, high=high, size=size, dtype=dtype
        )

    def normal(
        self,
        *,
        loc: Any,
        scale: Any,
        size: SizeLike,
        dtype: Any | None,
    ) -> Any:
        """Draw via :meth:`JAXFunctionalBackend.normal`."""
        return self._advance(
            "normal", loc=loc, scale=scale, size=size, dtype=dtype
        )

    def gamma(
        self,
        *,
        shape: Any,
        scale: Any,
        size: SizeLike,
        dtype: Any | None,
    ) -> Any:
        """Draw via :meth:`JAXFunctionalBackend.gamma`."""
        return self._advance(
            "gamma", shape=shape, scale=scale, size=size, dtype=dtype
        )

    def choice(
        self,
        population: int | Any,
        *,
        size: SizeLike,
        replace: bool,
        probabilities: Any | None,
    ) -> Any:
        """Sample via :meth:`JAXFunctionalBackend.choice`."""
        return self._advance(
            "choice",
            population,
            size=size,
            replace=replace,
            probabilities=probabilities,
        )
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
class JAXFunctionalBackend:
    """Functional sampler built on ``jax.random``.

    The state is a JAX PRNG key. Every sampling method splits the incoming
    key into a sampling key and a successor key, draws with the former and
    returns ``(result, successor_key)``.
    """

    def __init__(self) -> None:
        try:
            import jax
            import jax.numpy as jnp
        except ImportError as exc:  # pragma: no cover - optional dependency
            raise ImportError(
                "JAX backend requires the 'jax' package to be installed. "
                "Install it with `pip install orng[jax]`."
            ) from exc

        self._jax = jax
        self._jnp = jnp

    def init_state(self, *, seed: int | None, generator: Any | None) -> Any:
        """Return the initial PRNG key.

        An explicit *generator* is returned unchanged (it is presumably a JAX
        PRNG key; it is not validated here). Otherwise a key is built with
        ``jax.random.key(seed)``, drawing a 32-bit seed from :mod:`secrets`
        when *seed* is ``None``.
        """
        if generator is not None:
            return generator
        if seed is None:
            seed = secrets.randbits(32)
        return self._jax.random.key(seed)

    @staticmethod
    def _draw(sampler: Any, key: Any, shape: tuple[int, ...], **kwargs: Any) -> Any:
        """Call ``sampler(key, shape=..., **kwargs)``.

        Scalar requests (``shape == ()``) are drawn with shape ``(1,)`` and
        the single element is returned, as this backend has always done.
        """
        if shape:
            return sampler(key, shape=shape, **kwargs)
        return sampler(key, shape=(1,), **kwargs)[0]

    def random(
        self,
        state: Any,
        *,
        size: SizeLike,
        dtype: Any | None = None,
    ) -> tuple[Any, Any]:
        """Draw with ``jax.random.uniform`` on ``[0, 1)`` (float32 by default)."""
        key, next_state = self._jax.random.split(state)
        shape = normalize_shape(size)
        sample_dtype = dtype if dtype is not None else self._jnp.float32
        low = self._jnp.array(0.0, dtype=sample_dtype)
        high = self._jnp.array(1.0, dtype=sample_dtype)
        result = self._draw(
            self._jax.random.uniform,
            key,
            shape,
            minval=low,
            maxval=high,
            dtype=sample_dtype,
        )
        return result, next_state

    def uniform(
        self,
        state: Any,
        *,
        low: Any,
        high: Any,
        size: SizeLike,
        dtype: Any | None = None,
    ) -> tuple[Any, Any]:
        """Draw with ``jax.random.uniform(minval=low, maxval=high)``.

        *low* and *high* are cast to the sample dtype (float32 by default).
        """
        key, next_state = self._jax.random.split(state)
        shape = normalize_shape(size)
        sample_dtype = dtype if dtype is not None else self._jnp.float32
        low_arr = self._jnp.asarray(low, dtype=sample_dtype)
        high_arr = self._jnp.asarray(high, dtype=sample_dtype)
        result = self._draw(
            self._jax.random.uniform,
            key,
            shape,
            minval=low_arr,
            maxval=high_arr,
            dtype=sample_dtype,
        )
        return result, next_state

    def normal(
        self,
        state: Any,
        *,
        loc: Any,
        scale: Any,
        size: SizeLike,
        dtype: Any | None = None,
    ) -> tuple[Any, Any]:
        """Draw ``standard * scale + loc`` with ``jax.random.normal``."""
        key, next_state = self._jax.random.split(state)
        shape = normalize_shape(size)
        sample_dtype = dtype if dtype is not None else self._jnp.float32
        standard = self._draw(
            self._jax.random.normal, key, shape, dtype=sample_dtype
        )
        return standard * scale + loc, next_state

    def gamma(
        self,
        state: Any,
        *,
        shape: Any,
        scale: Any,
        size: SizeLike,
        dtype: Any | None = None,
    ) -> tuple[Any, Any]:
        """Draw ``Gamma(shape) * scale`` with ``jax.random.gamma``.

        The draw shape is ``normalize_shape(size)`` followed by the shape of
        the concentration (*shape*) array.
        """
        key, next_state = self._jax.random.split(state)
        sample_shape = normalize_shape(size)
        sample_dtype = dtype if dtype is not None else self._jnp.float32
        concentration = self._jnp.asarray(shape, dtype=sample_dtype)
        scale_arr = self._jnp.asarray(scale, dtype=sample_dtype)
        # NOTE(review): *size* is prepended to the concentration's own shape
        # (batch-style). NumPy-style ``gamma(shape, scale, size)`` treats
        # *size* as the full output shape instead; the two agree only for a
        # scalar *shape*. Confirm which convention is intended.
        draw_shape = sample_shape + self._jnp.shape(concentration)
        gamma_samples = self._jax.random.gamma(
            key,
            concentration,
            shape=draw_shape,
            dtype=sample_dtype,
        )
        return gamma_samples * scale_arr, next_state

    def choice(
        self,
        state: Any,
        population: int | Any,
        *,
        size: SizeLike,
        replace: bool,
        probabilities: Any | None,
    ) -> tuple[Any, Any]:
        """Sample with ``jax.random.choice``.

        Args:
            state: Current PRNG key.
            population: An ``int`` ``n`` (sample from ``0 .. n-1``) or an
                array-like to sample from.
            size: Output shape; ``None`` returns a single 0-d sample.
            replace: Whether to sample with replacement.
            probabilities: Optional weights forwarded as ``p``.

        Returns:
            ``(samples, successor_key)``.
        """
        key, next_state = self._jax.random.split(state)
        shape = normalize_shape(size)
        if isinstance(population, int):
            domain = population
        else:
            domain = self._jnp.asarray(population)
        probs = (
            None if probabilities is None else self._jnp.asarray(probabilities)
        )
        # ``jax.random.choice`` requires ``shape`` to be a sequence; the empty
        # tuple requests a single 0-d sample. Passing ``None`` (as before)
        # raised TypeError for every ``size=None`` call.
        result = self._jax.random.choice(
            key,
            domain,
            shape=shape,
            replace=replace,
            p=probs,
        )
        return result, next_state
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
__all__ = [
    "JAXBackend",
    "JAXFunctionalBackend",
]
|