orng 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
orng/__init__.py ADDED
@@ -0,0 +1,23 @@
1
"""Package entry point: exposes ``__version__`` and re-exports the public API listed in ``__all__``."""

import importlib.metadata

# Resolve the version from the installed distribution metadata; when no
# distribution named "orng" is found (PackageNotFoundError) fall back to the
# placeholder string "unknown" instead of failing the import.
try:
    __version__ = importlib.metadata.version("orng")
except importlib.metadata.PackageNotFoundError:
    __version__ = "unknown"

# Re-export the backend-construction helpers and the user-facing generator
# classes so callers can use them directly from the top-level package.
from .backends import create_backend_from_xp, infer_backend_name_from_xp
from .functional import (
    create_functional_backend,
    create_functional_backend_from_xp,
)
from .orng import ArrayRNG, RandomGenerator

# Public API of the package (also what ``from orng import *`` exports).
__all__ = [
    "ArrayRNG",
    "RandomGenerator",
    "create_backend_from_xp",
    "create_functional_backend",
    "create_functional_backend_from_xp",
    "infer_backend_name_from_xp",
    "__version__",
]
orng/_utils.py ADDED
@@ -0,0 +1,29 @@
1
from __future__ import annotations

import math
from typing import Sequence, Tuple, Union
5
+
6
# Accepted forms of a ``size`` argument: ``None`` (a single scalar draw), one
# integer, or a sequence of integers. Spelled with ``typing.Union`` because this
# alias is evaluated at runtime (``from __future__ import annotations`` only
# defers annotations): ``int | Sequence[int]`` raises TypeError before Python
# 3.10, while ``Union[...]`` builds the very same object on every version.
SizeLike = Union[int, Sequence[int], None]
7
+
8
+
9
def normalize_shape(size: SizeLike) -> Tuple[int, ...]:
    """Convert a size argument to a canonical ``tuple`` form.

    Args:
        size: ``None`` (scalar draw), a single non-negative integer, or an
            iterable of non-negative integers. Any integer-like object that
            implements ``__index__`` (e.g. NumPy integer scalars) is accepted
            wherever a plain ``int`` is.

    Returns:
        ``()`` for ``None``, ``(size,)`` for a single integer, otherwise a
        tuple of Python ``int`` dimensions.

    Raises:
        ValueError: if the size, or any of its entries, is negative.
        TypeError: if ``size`` is neither integer-like nor iterable, or if an
            entry is not integer-like (e.g. ``2.5``) -- such entries are
            rejected rather than truncated.
    """
    import operator

    if size is None:
        return ()
    try:
        # operator.index accepts int and any __index__ implementer (NumPy
        # integer scalars, 0-d integer arrays) but refuses floats.
        scalar = operator.index(size)
    except TypeError:
        pass  # not a single integer: treat it as a sequence of dimensions
    else:
        if scalar < 0:
            raise ValueError("size must be non-negative.")
        return (scalar,)
    shape = tuple(operator.index(dim) for dim in size)
    if any(dim < 0 for dim in shape):
        raise ValueError("size entries must be non-negative.")
    return shape
21
+
22
+
23
def total_size(shape: Tuple[int, ...]) -> int:
    """Return how many elements an array with dimensions ``shape`` holds.

    A falsy ``shape`` (the scalar shape ``()``) counts as a single element.
    """
    return math.prod(shape) if shape else 1
27
+
28
+
29
# Shape helpers shared by the backend modules (they import SizeLike and
# normalize_shape from here).
__all__ = ["SizeLike", "normalize_shape", "total_size"]
@@ -0,0 +1,101 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any, Callable, Dict
4
+
5
+ from .cupy import CuPyBackend
6
+ from .jax import JAXBackend
7
+ from .numpy import NumPyBackend
8
+ from .torch import TorchBackend
9
+
10
BackendFactory = Callable[..., Any]


def _numpy_factory(*, seed=None, generator=None, device=None):
    # ``device`` is part of the uniform factory signature but not forwarded.
    return NumPyBackend(seed=seed, generator=generator)


def _cupy_factory(*, seed=None, generator=None, device=None):
    # ``device`` is part of the uniform factory signature but not forwarded.
    return CuPyBackend(seed=seed, generator=generator)


def _torch_factory(*, seed=None, generator=None, device=None):
    return TorchBackend(seed=seed, generator=generator, device=device)


def _jax_factory(*, seed=None, generator=None, device=None):
    # The JAX backend names its PRNG state ``key``; ``device`` is not forwarded.
    return JAXBackend(seed=seed, key=generator)


# Registry of stateful-backend factories keyed by lower-case backend name.
# Every factory takes keyword-only ``seed``/``generator``/``device``;
# "pytorch" is an alias sharing the "torch" factory.
_FACTORIES: Dict[str, BackendFactory] = {
    "numpy": _numpy_factory,
    "cupy": _cupy_factory,
    "torch": _torch_factory,
    "pytorch": _torch_factory,
    "jax": _jax_factory,
}
36
+
37
+
38
def infer_backend_name_from_xp(xp: Any) -> str:
    """Map an array namespace (``numpy``, ``jax.numpy``, ``torch``, ``cupy``)
    to the matching ORNG backend name.

    Raises:
        ImportError: if the optional ``array_api_compat`` package is missing.
        ValueError: if ``xp`` matches none of the supported namespaces.
    """
    try:
        from array_api_compat import (
            is_cupy_namespace,
            is_jax_namespace,
            is_numpy_namespace,
            is_torch_namespace,
        )
    except ImportError as exc:  # pragma: no cover - optional dependency
        raise ImportError(
            "Inferring ORNG backends from an array namespace requires "
            "'array_api_compat'."
        ) from exc

    # Probed in order; the first predicate that accepts ``xp`` wins.
    probes = (
        ("numpy", is_numpy_namespace),
        ("jax", is_jax_namespace),
        ("torch", is_torch_namespace),
        ("cupy", is_cupy_namespace),
    )
    for backend_name, matches in probes:
        if matches(xp):
            return backend_name
    raise ValueError("Unsupported array namespace for ORNG backend inference.")
61
+
62
+
63
def create_backend(
    name: str,
    *,
    seed: int | None = None,
    generator: Any | None = None,
    device: Any | None = None,
):
    """Instantiate the stateful backend registered under ``name``.

    Args:
        name: Backend name, matched case-insensitively ("numpy", "cupy",
            "torch"/"pytorch", "jax").
        seed: Optional seed forwarded to the backend factory.
        generator: Optional pre-built generator/key forwarded to the factory.
        device: Optional device; only the torch factory forwards it.

    All keyword arguments default to ``None``, matching the factory defaults,
    so callers no longer have to spell out unused ones.

    Raises:
        ValueError: if ``name`` is not a registered backend.
    """
    try:
        factory = _FACTORIES[name.lower()]
    except KeyError as exc:  # pragma: no cover - defensive
        # "pytorch" is only an alias of "torch"; list each backend once.
        supported = "', '".join(
            sorted({k for k in _FACTORIES if k != "pytorch"})
        )
        raise ValueError(
            f"Unsupported backend '{name}'. Expected one of '{supported}'."
        ) from exc
    return factory(seed=seed, generator=generator, device=device)
80
+
81
+
82
def create_backend_from_xp(
    xp: Any,
    *,
    seed: int | None = None,
    generator: Any | None = None,
    device: Any | None = None,
):
    """Create the stateful backend that matches the array namespace ``xp``.

    The backend name is inferred with ``infer_backend_name_from_xp`` and the
    remaining keyword arguments (all optional, defaulting to ``None`` like
    ``create_backend``) are forwarded unchanged.

    Raises:
        ImportError: if ``array_api_compat`` is not installed.
        ValueError: if ``xp`` is not a supported namespace.
    """
    return create_backend(
        infer_backend_name_from_xp(xp),
        seed=seed,
        generator=generator,
        device=device,
    )
95
+
96
+
97
# Public entry points of the backends package; the ``_FACTORIES`` registry
# stays private.
__all__ = [
    "create_backend",
    "create_backend_from_xp",
    "infer_backend_name_from_xp",
]
orng/backends/cupy.py ADDED
@@ -0,0 +1,244 @@
1
+ from __future__ import annotations
2
+
3
+ import copy
4
+ from typing import Any
5
+
6
+ from .._utils import SizeLike
7
+
8
+
9
class CuPyBackend:
    """Stateful CuPy sampler layered over ``CuPyFunctionalBackend``.

    The functional core is created with ``pure=False``, so its "state" is the
    ``cupy.random.Generator`` itself and each draw returns that generator as
    the next state; this wrapper stores it in ``self._state`` between calls.
    """

    def __init__(self, *, seed: int | None, generator: Any | None) -> None:
        functional = CuPyFunctionalBackend(pure=False)
        self._impl = functional
        self._state = functional.init_state(seed=seed, generator=generator)

    def _advance(self, method: Any, *args: Any, **kwargs: Any) -> Any:
        """Run one functional draw on the current state, keep its successor
        state, and hand back only the sample."""
        sample, successor = method(self._state, *args, **kwargs)
        self._state = successor
        return sample

    def random(self, *, size: SizeLike, dtype: Any | None) -> Any:
        """Delegate to ``CuPyFunctionalBackend.random``."""
        return self._advance(self._impl.random, size=size, dtype=dtype)

    def uniform(
        self,
        *,
        low: Any,
        high: Any,
        size: SizeLike,
        dtype: Any | None,
    ) -> Any:
        """Delegate to ``CuPyFunctionalBackend.uniform``."""
        return self._advance(
            self._impl.uniform, low=low, high=high, size=size, dtype=dtype
        )

    def normal(
        self,
        *,
        loc: Any,
        scale: Any,
        size: SizeLike,
        dtype: Any | None,
    ) -> Any:
        """Delegate to ``CuPyFunctionalBackend.normal``."""
        return self._advance(
            self._impl.normal, loc=loc, scale=scale, size=size, dtype=dtype
        )

    def gamma(
        self,
        *,
        shape: Any,
        scale: Any,
        size: SizeLike,
        dtype: Any | None,
    ) -> Any:
        """Delegate to ``CuPyFunctionalBackend.gamma``."""
        return self._advance(
            self._impl.gamma, shape=shape, scale=scale, size=size, dtype=dtype
        )

    def choice(
        self,
        population: int | Any,
        *,
        size: SizeLike,
        replace: bool,
        probabilities: Any | None,
    ) -> Any:
        """Delegate to ``CuPyFunctionalBackend.choice``."""
        return self._advance(
            self._impl.choice,
            population,
            size=size,
            replace=replace,
            probabilities=probabilities,
        )
89
+
90
+
91
class CuPyFunctionalBackend:
    """Functional sampling API on top of ``cupy.random.Generator``.

    Every sampling method takes a ``state`` and returns ``(result, next_state)``.

    * ``pure=True`` (default): the state is a deep copy of
      ``gen.bit_generator.state``. A fresh generator is rebuilt from it on
      every call, so the caller's state object is never mutated.
    * ``pure=False``: the state is the ``cupy.random.Generator`` itself; it is
      advanced in place and returned as ``next_state``.
    """

    def __init__(self, *, pure: bool = True) -> None:
        try:
            import cupy as cp
        except ImportError as exc:  # pragma: no cover - optional dependency
            raise ImportError(
                "CuPy backend requires the 'cupy' package to be installed. "
                "Install it with `pip install orng[cupy]`."
            ) from exc
        self._cupy = cp
        self._pure = pure

    def init_state(self, *, seed: int | None, generator: Any | None) -> Any:
        """Build the initial state from ``seed`` or an existing generator.

        Raises:
            TypeError: if ``generator`` is given but is not a
                ``cupy.random.Generator``.
        """
        cp = self._cupy
        if generator is None:
            gen = cp.random.default_rng(seed)
        elif isinstance(generator, cp.random.Generator):
            gen = generator
        else:
            raise TypeError(
                "generator must be a cupy.random.Generator when using the "
                "CuPy backend."
            )
        if self._pure:
            # Snapshot only: a user-supplied generator is not advanced later.
            return copy.deepcopy(gen.bit_generator.state)
        return gen

    def _generator_from_state(self, state: Any) -> Any:
        """Return a generator positioned at ``state`` (see class docstring)."""
        if not self._pure:
            if not isinstance(state, self._cupy.random.Generator):
                raise TypeError(
                    "state must be a cupy.random.Generator when pure=False."
                )
            return state
        gen = self._cupy.random.default_rng()
        gen.bit_generator.state = copy.deepcopy(state)
        return gen

    def _result_and_next_state(self, gen: Any, result: Any) -> tuple[Any, Any]:
        """Pair ``result`` with the state ``gen`` has advanced to."""
        if self._pure:
            return result, copy.deepcopy(gen.bit_generator.state)
        return result, gen

    def random(
        self,
        state: Any,
        *,
        size: SizeLike,
        dtype: Any | None = None,
    ) -> tuple[Any, Any]:
        """Draw via ``Generator.random``."""
        gen = self._generator_from_state(state)
        result = gen.random(size=size, dtype=dtype)
        return self._result_and_next_state(gen, result)

    def uniform(
        self,
        state: Any,
        *,
        low: Any,
        high: Any,
        size: SizeLike,
        dtype: Any | None = None,
    ) -> tuple[Any, Any]:
        """Draw via ``Generator.uniform``."""
        gen = self._generator_from_state(state)
        result = gen.uniform(low=low, high=high, size=size, dtype=dtype)
        return self._result_and_next_state(gen, result)

    def normal(
        self,
        state: Any,
        *,
        loc: Any,
        scale: Any,
        size: SizeLike,
        dtype: Any | None = None,
    ) -> tuple[Any, Any]:
        """Draw standard normals and shift/scale them: ``loc + z * scale``."""
        gen = self._generator_from_state(state)
        standard = gen.standard_normal(size=size, dtype=dtype)
        result = loc + standard * scale
        return self._result_and_next_state(gen, result)

    def gamma(
        self,
        state: Any,
        *,
        shape: Any,
        scale: Any,
        size: SizeLike,
        dtype: Any | None = None,
    ) -> tuple[Any, Any]:
        """Draw via ``Generator.gamma``; cast to ``dtype`` afterwards if given."""
        gen = self._generator_from_state(state)
        result = gen.gamma(shape=shape, scale=scale, size=size)
        if dtype is not None:
            result = self._cupy.asarray(result, dtype=dtype)
        return self._result_and_next_state(gen, result)

    def choice(
        self,
        state: Any,
        population: int | Any,
        *,
        size: SizeLike,
        replace: bool,
        probabilities: Any | None,
    ) -> tuple[Any, Any]:
        """Sample from ``range(population)`` (int) or from the leading axis of
        ``population`` (array-like), optionally weighted by ``probabilities``
        (normalized to sum to one).

        ``size=None`` yields a single sample rather than an array of samples.

        Raises:
            ValueError: if ``probabilities`` does not have exactly one entry per
                population element; or, with ``replace=False``, if more samples
                are requested than there are (non-zero-probability) candidates.
        """
        cp = self._cupy
        gen = self._generator_from_state(state)

        if isinstance(population, int):
            n = population
            values = None
        else:
            values = cp.asarray(population)
            n = values.shape[0]

        target_size = size if size is not None else ()

        probs = None
        if probabilities is not None:
            probs = cp.asarray(probabilities, dtype=cp.float64)
            if probs.shape != (n,):
                raise ValueError(
                    "probabilities must be a 1-D array with one entry per "
                    "population element."
                )
            probs = probs / cp.sum(probs)

        if replace:
            if probs is None:
                indices = gen.integers(0, n, size=target_size)
            else:
                # Inverse-CDF sampling. side="right" maps u in
                # [cdf[i-1], cdf[i]) to i, so zero-probability entries
                # (zero-width intervals) are never chosen; the clip guards
                # against cdf[-1] rounding slightly below 1.0, which would
                # otherwise produce the out-of-range index n.
                cdf = cp.cumsum(probs)
                draws = gen.random(size=target_size)
                indices = cp.searchsorted(cdf, draws, side="right")
                indices = cp.minimum(indices, n - 1)
        else:
            flat_k = (
                1
                if target_size == ()
                else int(cp.prod(cp.asarray(target_size)))
            )
            if flat_k > n:
                raise ValueError(
                    "Cannot take a larger sample than population when "
                    "replace=False."
                )
            if probs is None:
                indices = gen.permutation(n)[:flat_k]
            else:
                if int(cp.count_nonzero(probs)) < flat_k:
                    raise ValueError(
                        "Fewer non-zero entries in probabilities than "
                        "requested samples when replace=False."
                    )
                # Gumbel top-k: the k largest log(p) + Gumbel noise keys form
                # a weighted sample without replacement.
                gumbels = -cp.log(-cp.log(gen.random(n)))
                keys = cp.log(probs) + gumbels
                indices = cp.argpartition(keys, -flat_k)[-flat_k:]
            if target_size != ():
                indices = cp.reshape(indices, target_size)
            else:
                indices = indices[0]

        if values is None:
            result = indices
        else:
            result = values[indices]
        return self._result_and_next_state(gen, result)
+ return self._result_and_next_state(gen, result)
242
+
243
+
244
# Exported names: the stateful wrapper and its functional core.
__all__ = ["CuPyBackend", "CuPyFunctionalBackend"]
orng/backends/jax.py ADDED
@@ -0,0 +1,256 @@
1
+ from __future__ import annotations
2
+
3
+ import secrets
4
+ from typing import Any
5
+
6
+ from .._utils import SizeLike, normalize_shape
7
+
8
+
9
class JAXBackend:
    """Stateful JAX sampler layered over ``JAXFunctionalBackend``.

    The current PRNG key lives in ``self._state``; each draw replaces it with
    the ``next_state`` key returned by the functional core.
    """

    def __init__(self, *, seed: int | None, key: Any | None) -> None:
        functional = JAXFunctionalBackend()
        self._impl = functional
        self._state = functional.init_state(seed=seed, generator=key)

    def _advance(self, method: Any, *args: Any, **kwargs: Any) -> Any:
        """Run one functional draw on the current key, keep the successor key,
        and hand back only the sample."""
        sample, successor = method(self._state, *args, **kwargs)
        self._state = successor
        return sample

    def random(self, *, size: SizeLike, dtype: Any | None) -> Any:
        """Delegate to ``JAXFunctionalBackend.random``."""
        return self._advance(self._impl.random, size=size, dtype=dtype)

    def uniform(
        self,
        *,
        low: Any,
        high: Any,
        size: SizeLike,
        dtype: Any | None,
    ) -> Any:
        """Delegate to ``JAXFunctionalBackend.uniform``."""
        return self._advance(
            self._impl.uniform, low=low, high=high, size=size, dtype=dtype
        )

    def normal(
        self,
        *,
        loc: Any,
        scale: Any,
        size: SizeLike,
        dtype: Any | None,
    ) -> Any:
        """Delegate to ``JAXFunctionalBackend.normal``."""
        return self._advance(
            self._impl.normal, loc=loc, scale=scale, size=size, dtype=dtype
        )

    def gamma(
        self,
        *,
        shape: Any,
        scale: Any,
        size: SizeLike,
        dtype: Any | None,
    ) -> Any:
        """Delegate to ``JAXFunctionalBackend.gamma``."""
        return self._advance(
            self._impl.gamma, shape=shape, scale=scale, size=size, dtype=dtype
        )

    def choice(
        self,
        population: int | Any,
        *,
        size: SizeLike,
        replace: bool,
        probabilities: Any | None,
    ) -> Any:
        """Delegate to ``JAXFunctionalBackend.choice``."""
        return self._advance(
            self._impl.choice,
            population,
            size=size,
            replace=replace,
            probabilities=probabilities,
        )
89
+
90
+
91
class JAXFunctionalBackend:
    """Functional sampling API on top of ``jax.random``.

    The state is a JAX PRNG key. Every sampling method splits it into a
    subkey consumed by the draw and a fresh key returned as ``next_state``.
    """

    def __init__(self) -> None:
        try:
            import jax
            import jax.numpy as jnp
        except ImportError as exc:  # pragma: no cover - optional dependency
            raise ImportError(
                "JAX backend requires the 'jax' package to be installed. "
                "Install it with `pip install orng[jax]`."
            ) from exc

        self._jax = jax
        self._jnp = jnp

    def init_state(self, *, seed: int | None, generator: Any | None) -> Any:
        """Return ``generator`` unchanged when given (presumably a JAX PRNG
        key -- it is not validated), otherwise ``jax.random.key(seed)``, with
        32 random bits from :mod:`secrets` when ``seed`` is None."""
        if generator is not None:
            return generator
        if seed is None:
            seed = secrets.randbits(32)
        return self._jax.random.key(seed)

    @staticmethod
    def _draw(sampler: Any, key: Any, shape: tuple, **kwargs: Any) -> Any:
        """Call ``sampler(key, shape=shape, **kwargs)``.

        A scalar request (``shape == ()``) is drawn as shape ``(1,)`` and
        unwrapped; this form is kept rather than passing ``shape=()`` so the
        produced values are unchanged.
        """
        if shape:
            return sampler(key, shape=shape, **kwargs)
        return sampler(key, shape=(1,), **kwargs)[0]

    def random(
        self,
        state: Any,
        *,
        size: SizeLike,
        dtype: Any | None = None,
    ) -> tuple[Any, Any]:
        """Uniform draw between 0 and 1 (``jax.random.uniform``); float32 by default."""
        key, next_state = self._jax.random.split(state)
        shape = normalize_shape(size)
        sample_dtype = dtype if dtype is not None else self._jnp.float32
        low = self._jnp.array(0.0, dtype=sample_dtype)
        high = self._jnp.array(1.0, dtype=sample_dtype)
        result = self._draw(
            self._jax.random.uniform,
            key,
            shape,
            minval=low,
            maxval=high,
            dtype=sample_dtype,
        )
        return result, next_state

    def uniform(
        self,
        state: Any,
        *,
        low: Any,
        high: Any,
        size: SizeLike,
        dtype: Any | None = None,
    ) -> tuple[Any, Any]:
        """Uniform draw between ``low`` and ``high``; float32 by default."""
        key, next_state = self._jax.random.split(state)
        shape = normalize_shape(size)
        sample_dtype = dtype if dtype is not None else self._jnp.float32
        low_arr = self._jnp.asarray(low, dtype=sample_dtype)
        high_arr = self._jnp.asarray(high, dtype=sample_dtype)
        result = self._draw(
            self._jax.random.uniform,
            key,
            shape,
            minval=low_arr,
            maxval=high_arr,
            dtype=sample_dtype,
        )
        return result, next_state

    def normal(
        self,
        state: Any,
        *,
        loc: Any,
        scale: Any,
        size: SizeLike,
        dtype: Any | None = None,
    ) -> tuple[Any, Any]:
        """Standard normal draw shifted/scaled as ``z * scale + loc``."""
        key, next_state = self._jax.random.split(state)
        shape = normalize_shape(size)
        sample_dtype = dtype if dtype is not None else self._jnp.float32
        standard = self._draw(
            self._jax.random.normal, key, shape, dtype=sample_dtype
        )
        return standard * scale + loc, next_state

    def gamma(
        self,
        state: Any,
        *,
        shape: Any,
        scale: Any,
        size: SizeLike,
        dtype: Any | None = None,
    ) -> tuple[Any, Any]:
        """Gamma draw with concentration ``shape``, multiplied by ``scale``."""
        key, next_state = self._jax.random.split(state)
        sample_shape = normalize_shape(size)
        sample_dtype = dtype if dtype is not None else self._jnp.float32
        concentration = self._jnp.asarray(shape, dtype=sample_dtype)
        scale_arr = self._jnp.asarray(scale, dtype=sample_dtype)
        # NOTE(review): with an array-valued ``shape`` and an explicit ``size``
        # the draw shape is ``size + shape.shape``, whereas the CuPy backend
        # forwards ``size`` to ``Generator.gamma`` (presumably NumPy-style,
        # where ``size`` is the full output shape). Confirm which contract is
        # intended before relying on array-valued parameters.
        if sample_shape:
            draw_shape = sample_shape + self._jnp.shape(concentration)
        else:
            draw_shape = self._jnp.shape(concentration)
        gamma_samples = self._jax.random.gamma(
            key,
            concentration,
            shape=draw_shape,
            dtype=sample_dtype,
        )
        scaled = gamma_samples * scale_arr
        return scaled, next_state

    def choice(
        self,
        state: Any,
        population: int | Any,
        *,
        size: SizeLike,
        replace: bool,
        probabilities: Any | None,
    ) -> tuple[Any, Any]:
        """Sample from ``range(population)`` (int) or from ``population``'s
        leading axis, optionally weighted by ``probabilities``.

        ``size=None`` (or ``()``) draws a single sample with no leading sample
        axes.
        """
        key, next_state = self._jax.random.split(state)
        shape = normalize_shape(size)
        if isinstance(population, int):
            domain = population
        else:
            domain = self._jnp.asarray(population)
        probs = (
            None if probabilities is None else self._jnp.asarray(probabilities)
        )
        # jax.random.choice requires ``shape`` to be a sequence (passing None
        # raises TypeError); the normalized ``()`` requests a scalar draw.
        result = self._jax.random.choice(
            key,
            domain,
            shape=shape,
            replace=replace,
            p=probs,
        )
        return result, next_state
254
+
255
+
256
# Exported names: the stateful wrapper and its functional core.
__all__ = ["JAXBackend", "JAXFunctionalBackend"]