seedall 0.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- seedall-0.2.0/PKG-INFO +122 -0
- seedall-0.2.0/README.md +93 -0
- seedall-0.2.0/pyproject.toml +24 -0
- seedall-0.2.0/src/seedall/__init__.py +21 -0
- seedall-0.2.0/src/seedall/core.py +348 -0
- seedall-0.2.0/src/seedall/py.typed +0 -0
seedall-0.2.0/PKG-INFO
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
1
|
+
Metadata-Version: 2.3
|
|
2
|
+
Name: seedall
|
|
3
|
+
Version: 0.2.0
|
|
4
|
+
Summary: Seed all common RNGs in one call for reproducible experiments.
|
|
5
|
+
Keywords: seed,reproducibility,random,numpy,pytorch,tensorflow,jax
|
|
6
|
+
Author: jdh
|
|
7
|
+
Author-email: jdh <you@example.com>
|
|
8
|
+
License: MIT
|
|
9
|
+
Requires-Dist: numpy ; extra == 'all'
|
|
10
|
+
Requires-Dist: torch ; extra == 'all'
|
|
11
|
+
Requires-Dist: tensorflow ; extra == 'all'
|
|
12
|
+
Requires-Dist: jax ; extra == 'all'
|
|
13
|
+
Requires-Dist: jaxlib ; extra == 'all'
|
|
14
|
+
Requires-Dist: pytest ; extra == 'dev'
|
|
15
|
+
Requires-Dist: numpy ; extra == 'dev'
|
|
16
|
+
Requires-Dist: jax ; extra == 'jax'
|
|
17
|
+
Requires-Dist: jaxlib ; extra == 'jax'
|
|
18
|
+
Requires-Dist: numpy ; extra == 'numpy'
|
|
19
|
+
Requires-Dist: tensorflow ; extra == 'tensorflow'
|
|
20
|
+
Requires-Dist: torch ; extra == 'torch'
|
|
21
|
+
Requires-Python: >=3.12
|
|
22
|
+
Provides-Extra: all
|
|
23
|
+
Provides-Extra: dev
|
|
24
|
+
Provides-Extra: jax
|
|
25
|
+
Provides-Extra: numpy
|
|
26
|
+
Provides-Extra: tensorflow
|
|
27
|
+
Provides-Extra: torch
|
|
28
|
+
Description-Content-Type: text/markdown
|
|
29
|
+
|
|
30
|
+
# seedall
|
|
31
|
+
|
|
32
|
+
Seed **all** common RNGs in one call for reproducible experiments.
|
|
33
|
+
|
|
34
|
+
```python
|
|
35
|
+
import seedall
|
|
36
|
+
|
|
37
|
+
seedall.seed(42) # seeds random, numpy, torch, tensorflow, jax, cupy — whatever is installed
|
|
38
|
+
```
|
|
39
|
+
|
|
40
|
+
## Supported backends
|
|
41
|
+
|
|
42
|
+
| Backend | What gets seeded |
|
|
43
|
+
|------------|----------------------------------------------------------|
|
|
44
|
+
| `random` | Python stdlib `random` |
|
|
45
|
+
| `hashseed` | `PYTHONHASHSEED` env var |
|
|
46
|
+
| `numpy` | `np.random.seed()` |
|
|
47
|
+
| `torch` | `torch.manual_seed()` + `cuda.manual_seed_all()` |
|
|
48
|
+
| `tensorflow` | `tf.random.set_seed()` |
|
|
49
|
+
| `jax` | Creates a `jax.random.PRNGKey` (retrieve via states API) |
|
|
50
|
+
| `cupy` | `cp.random.seed()` |
|
|
51
|
+
|
|
52
|
+
Missing libraries are silently skipped — install only what you need.
|
|
53
|
+
|
|
54
|
+
## API
|
|
55
|
+
|
|
56
|
+
### `seedall.seed(value, *, backends=None, deterministic=False, warn_missing=False)`
|
|
57
|
+
|
|
58
|
+
Seed all (or selected) backends. Returns `dict[str, bool]` showing what was seeded.
|
|
59
|
+
|
|
60
|
+
```python
|
|
61
|
+
# Seed everything
|
|
62
|
+
seedall.seed(42)
|
|
63
|
+
|
|
64
|
+
# Seed only specific backends
|
|
65
|
+
seedall.seed(42, backends=["numpy", "torch"])
|
|
66
|
+
|
|
67
|
+
# Also enable PyTorch deterministic mode (slower but fully reproducible)
|
|
68
|
+
seedall.seed(42, deterministic=True)
|
|
69
|
+
```
|
|
70
|
+
|
|
71
|
+
### `seedall.temp_seed(value, *, deterministic=False)`
|
|
72
|
+
|
|
73
|
+
Context manager — seeds on entry, restores previous RNG states on exit.
|
|
74
|
+
|
|
75
|
+
```python
|
|
76
|
+
with seedall.temp_seed(0):
|
|
77
|
+
    x = np.random.rand(100) # reproducible
|
|
78
|
+
y = np.random.rand(100) # back to original sequence
|
|
79
|
+
```
|
|
80
|
+
|
|
81
|
+
### `seedall.available()`
|
|
82
|
+
|
|
83
|
+
List detected backends:
|
|
84
|
+
|
|
85
|
+
```python
|
|
86
|
+
>>> seedall.available()
|
|
87
|
+
['random', 'hashseed', 'numpy', 'torch']
|
|
88
|
+
```
|
|
89
|
+
|
|
90
|
+
### `seedall.get_states()` / `seedall.set_states(states)`
|
|
91
|
+
|
|
92
|
+
Snapshot and restore RNG states manually:
|
|
93
|
+
|
|
94
|
+
```python
|
|
95
|
+
states = seedall.get_states()
|
|
96
|
+
# ... do stuff ...
|
|
97
|
+
seedall.set_states(states) # rewind
|
|
98
|
+
```
|
|
99
|
+
|
|
100
|
+
### `seedall.SeedContext(value)`
|
|
101
|
+
|
|
102
|
+
Class-based alternative when a context manager isn't convenient:
|
|
103
|
+
|
|
104
|
+
```python
|
|
105
|
+
ctx = seedall.SeedContext(42)
|
|
106
|
+
ctx.enter() # seed
|
|
107
|
+
# ... run experiment ...
|
|
108
|
+
ctx.exit() # restore
|
|
109
|
+
```
|
|
110
|
+
|
|
111
|
+
## Install
|
|
112
|
+
|
|
113
|
+
```bash
|
|
114
|
+
pip install seedall # core (stdlib random only)
|
|
115
|
+
pip install seedall[numpy] # + numpy
|
|
116
|
+
pip install seedall[torch] # + pytorch
|
|
117
|
+
pip install seedall[all] # everything
|
|
118
|
+
```
|
|
119
|
+
|
|
120
|
+
## License
|
|
121
|
+
|
|
122
|
+
MIT
|
seedall-0.2.0/README.md
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
# seedall
|
|
2
|
+
|
|
3
|
+
Seed **all** common RNGs in one call for reproducible experiments.
|
|
4
|
+
|
|
5
|
+
```python
|
|
6
|
+
import seedall
|
|
7
|
+
|
|
8
|
+
seedall.seed(42) # seeds random, numpy, torch, tensorflow, jax, cupy — whatever is installed
|
|
9
|
+
```
|
|
10
|
+
|
|
11
|
+
## Supported backends
|
|
12
|
+
|
|
13
|
+
| Backend | What gets seeded |
|
|
14
|
+
|------------|----------------------------------------------------------|
|
|
15
|
+
| `random` | Python stdlib `random` |
|
|
16
|
+
| `hashseed` | `PYTHONHASHSEED` env var |
|
|
17
|
+
| `numpy` | `np.random.seed()` |
|
|
18
|
+
| `torch` | `torch.manual_seed()` + `cuda.manual_seed_all()` |
|
|
19
|
+
| `tensorflow` | `tf.random.set_seed()` |
|
|
20
|
+
| `jax` | Creates a `jax.random.PRNGKey` (retrieve via states API) |
|
|
21
|
+
| `cupy` | `cp.random.seed()` |
|
|
22
|
+
|
|
23
|
+
Missing libraries are silently skipped — install only what you need.
|
|
24
|
+
|
|
25
|
+
## API
|
|
26
|
+
|
|
27
|
+
### `seedall.seed(value, *, backends=None, deterministic=False, warn_missing=False)`
|
|
28
|
+
|
|
29
|
+
Seed all (or selected) backends. Returns `dict[str, bool]` showing what was seeded.
|
|
30
|
+
|
|
31
|
+
```python
|
|
32
|
+
# Seed everything
|
|
33
|
+
seedall.seed(42)
|
|
34
|
+
|
|
35
|
+
# Seed only specific backends
|
|
36
|
+
seedall.seed(42, backends=["numpy", "torch"])
|
|
37
|
+
|
|
38
|
+
# Also enable PyTorch deterministic mode (slower but fully reproducible)
|
|
39
|
+
seedall.seed(42, deterministic=True)
|
|
40
|
+
```
|
|
41
|
+
|
|
42
|
+
### `seedall.temp_seed(value, *, deterministic=False)`
|
|
43
|
+
|
|
44
|
+
Context manager — seeds on entry, restores previous RNG states on exit.
|
|
45
|
+
|
|
46
|
+
```python
|
|
47
|
+
with seedall.temp_seed(0):
|
|
48
|
+
    x = np.random.rand(100) # reproducible
|
|
49
|
+
y = np.random.rand(100) # back to original sequence
|
|
50
|
+
```
|
|
51
|
+
|
|
52
|
+
### `seedall.available()`
|
|
53
|
+
|
|
54
|
+
List detected backends:
|
|
55
|
+
|
|
56
|
+
```python
|
|
57
|
+
>>> seedall.available()
|
|
58
|
+
['random', 'hashseed', 'numpy', 'torch']
|
|
59
|
+
```
|
|
60
|
+
|
|
61
|
+
### `seedall.get_states()` / `seedall.set_states(states)`
|
|
62
|
+
|
|
63
|
+
Snapshot and restore RNG states manually:
|
|
64
|
+
|
|
65
|
+
```python
|
|
66
|
+
states = seedall.get_states()
|
|
67
|
+
# ... do stuff ...
|
|
68
|
+
seedall.set_states(states) # rewind
|
|
69
|
+
```
|
|
70
|
+
|
|
71
|
+
### `seedall.SeedContext(value)`
|
|
72
|
+
|
|
73
|
+
Class-based alternative when a context manager isn't convenient:
|
|
74
|
+
|
|
75
|
+
```python
|
|
76
|
+
ctx = seedall.SeedContext(42)
|
|
77
|
+
ctx.enter() # seed
|
|
78
|
+
# ... run experiment ...
|
|
79
|
+
ctx.exit() # restore
|
|
80
|
+
```
|
|
81
|
+
|
|
82
|
+
## Install
|
|
83
|
+
|
|
84
|
+
```bash
|
|
85
|
+
pip install seedall # core (stdlib random only)
|
|
86
|
+
pip install seedall[numpy] # + numpy
|
|
87
|
+
pip install seedall[torch] # + pytorch
|
|
88
|
+
pip install seedall[all] # everything
|
|
89
|
+
```
|
|
90
|
+
|
|
91
|
+
## License
|
|
92
|
+
|
|
93
|
+
MIT
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "seedall"
|
|
3
|
+
version = "0.2.0"
|
|
4
|
+
license = {text = "MIT"}
|
|
5
|
+
description = "Seed all common RNGs in one call for reproducible experiments."
|
|
6
|
+
readme = "README.md"
|
|
7
|
+
authors = [
|
|
8
|
+
{ name = "jdh", email = "you@example.com" }
|
|
9
|
+
]
|
|
10
|
+
requires-python = ">=3.12"
|
|
11
|
+
dependencies = []
|
|
12
|
+
keywords = ["seed", "reproducibility", "random", "numpy", "pytorch", "tensorflow", "jax"]
|
|
13
|
+
|
|
14
|
+
[build-system]
|
|
15
|
+
requires = ["uv_build>=0.10.4,<0.11.0"]
|
|
16
|
+
build-backend = "uv_build"
|
|
17
|
+
|
|
18
|
+
[project.optional-dependencies]
|
|
19
|
+
numpy = ["numpy"]
|
|
20
|
+
torch = ["torch"]
|
|
21
|
+
tensorflow = ["tensorflow"]
|
|
22
|
+
jax = ["jax", "jaxlib"]
|
|
23
|
+
all = ["numpy", "torch", "tensorflow", "jax", "jaxlib"]
|
|
24
|
+
dev = ["pytest", "numpy"]
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
"""Seed all common RNGs in one call for reproducible experiments."""
|
|
2
|
+
|
|
3
|
+
from .core import (
|
|
4
|
+
Backend,
|
|
5
|
+
SeedContext,
|
|
6
|
+
available,
|
|
7
|
+
get_states,
|
|
8
|
+
seed,
|
|
9
|
+
set_states,
|
|
10
|
+
temp_seed,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
__all__ = [
|
|
14
|
+
"Backend",
|
|
15
|
+
"SeedContext",
|
|
16
|
+
"available",
|
|
17
|
+
"get_states",
|
|
18
|
+
"seed",
|
|
19
|
+
"set_states",
|
|
20
|
+
"temp_seed",
|
|
21
|
+
]
|
|
@@ -0,0 +1,348 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Core seeding logic for seedall.
|
|
3
|
+
|
|
4
|
+
Supports: random, numpy, torch (CPU + CUDA), tensorflow, JAX, cupy.
|
|
5
|
+
Each backend is optional -- missing libraries are silently skipped.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import os
|
|
11
|
+
import random
|
|
12
|
+
import logging
|
|
13
|
+
import threading
|
|
14
|
+
import warnings
|
|
15
|
+
from contextlib import contextmanager
|
|
16
|
+
from dataclasses import dataclass, field
|
|
17
|
+
from typing import Any, Dict, Generator, List, Optional
|
|
18
|
+
|
|
19
|
+
logger = logging.getLogger(__name__)
|
|
20
|
+
|
|
21
|
+
_lock = threading.Lock()
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
# ---------------------------------------------------------------------------
|
|
25
|
+
# Backend registry
|
|
26
|
+
# ---------------------------------------------------------------------------
|
|
27
|
+
|
|
28
|
+
@dataclass
class Backend:
    """Registry record describing one RNG backend (e.g. numpy, torch).

    The three callables are stored untyped; the rest of this module invokes
    them as ``seed_fn(seed)``, ``get_state_fn()`` and ``set_state_fn(state)``.
    """

    # Key under which the backend is registered and reported.
    name: str
    # Called with the integer seed; the return value is ignored.
    seed_fn: Any
    # Called with no arguments; returns an opaque snapshot of the RNG state.
    get_state_fn: Any
    # Called with a snapshot previously produced by ``get_state_fn``.
    set_state_fn: Any
    # Defaults to True; nothing else in this module reads the flag.
    available: bool = True
    # Backend-specific helpers; each instance gets its own fresh dict.
    extras: Dict[str, Any] = field(default_factory=dict)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
_BACKENDS: Dict[str, Backend] = {}
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def _register_builtin_backends() -> None:
    """Populate ``_BACKENDS`` with every supported backend that can be imported.

    Registration order is fixed: random, hashseed, numpy, torch, tensorflow,
    jax, cupy. A backend whose import raises ``ImportError`` is skipped with
    a debug-level log message.
    """
    _register_random_backend()
    _register_hashseed_backend()
    _register_numpy_backend()
    _register_torch_backend()
    _register_tensorflow_backend()
    _register_jax_backend()
    _register_cupy_backend()


def _register_random_backend() -> None:
    """Register the Python stdlib ``random`` module."""
    _BACKENDS["random"] = Backend(
        name="random",
        seed_fn=random.seed,
        get_state_fn=random.getstate,
        set_state_fn=random.setstate,
    )


def _register_hashseed_backend() -> None:
    """Register the ``PYTHONHASHSEED`` environment variable (best-effort)."""

    def _write(seed_value: int) -> None:
        os.environ["PYTHONHASHSEED"] = str(seed_value)

    def _read() -> Optional[str]:
        return os.environ.get("PYTHONHASHSEED")

    def _restore(previous: Any) -> None:
        # A ``None`` snapshot means the variable was unset when captured.
        if previous is None:
            os.environ.pop("PYTHONHASHSEED", None)
        else:
            os.environ["PYTHONHASHSEED"] = str(previous)

    _BACKENDS["hashseed"] = Backend(
        name="hashseed",
        seed_fn=_write,
        get_state_fn=_read,
        set_state_fn=_restore,
    )


def _register_numpy_backend() -> None:
    """Register NumPy's legacy global RNG, if numpy is importable."""
    try:
        import numpy as np

        def _seed_global(seed_value: int) -> None:
            np.random.seed(seed_value)

        _BACKENDS["numpy"] = Backend(
            name="numpy",
            seed_fn=_seed_global,
            get_state_fn=np.random.get_state,
            set_state_fn=np.random.set_state,
        )
    except ImportError:
        logger.debug("numpy not found -- skipping")


def _register_torch_backend() -> None:
    """Register PyTorch's CPU and (when present) CUDA RNGs."""
    try:
        import torch

        def _seed_everywhere(seed_value: int) -> None:
            torch.manual_seed(seed_value)
            if not torch.cuda.is_available():
                return
            torch.cuda.manual_seed(seed_value)
            torch.cuda.manual_seed_all(seed_value)

        def _snapshot() -> dict:
            # The "cuda" entry holds one state per device, in device order,
            # and is only present when CUDA is available.
            snapshot = {"cpu": torch.random.get_rng_state()}
            if torch.cuda.is_available():
                device_total = torch.cuda.device_count()
                snapshot["cuda"] = [
                    torch.cuda.get_rng_state(index) for index in range(device_total)
                ]
            return snapshot

        def _restore(snapshot: dict) -> None:
            torch.random.set_rng_state(snapshot["cpu"])
            if "cuda" not in snapshot or not torch.cuda.is_available():
                return
            for index, device_state in enumerate(snapshot["cuda"]):
                torch.cuda.set_rng_state(device_state, index)

        _BACKENDS["torch"] = Backend(
            name="torch",
            seed_fn=_seed_everywhere,
            get_state_fn=_snapshot,
            set_state_fn=_restore,
        )
    except ImportError:
        logger.debug("torch not found -- skipping")


def _register_tensorflow_backend() -> None:
    """Register TensorFlow's global seed.

    NOTE: TensorFlow does not expose a global RNG get/set state API.
    Seeding works, but get_states()/set_states() are no-ops for this backend.
    """
    try:
        import tensorflow as tf

        def _seed_global(seed_value: int) -> None:
            tf.random.set_seed(seed_value)

        def _no_snapshot() -> None:
            return None

        def _ignore_snapshot(_snapshot: Any) -> None:
            return None

        _BACKENDS["tensorflow"] = Backend(
            name="tensorflow",
            seed_fn=_seed_global,
            get_state_fn=_no_snapshot,
            set_state_fn=_ignore_snapshot,
        )
    except ImportError:
        logger.debug("tensorflow not found -- skipping")


def _register_jax_backend() -> None:
    """Register JAX by keeping a module-held default PRNG key."""
    try:
        import jax

        # JAX uses explicit PRNG keys rather than global state, so a
        # "default key" is stored here; it is exposed through the snapshot
        # returned by seedall.get_states() and through extras["get_key"].
        key_holder: Dict[str, Any] = {"key": None}

        def _make_key(seed_value: int) -> None:
            key_holder["key"] = jax.random.PRNGKey(seed_value)

        def _snapshot() -> Dict[str, Any]:
            return key_holder.copy()

        def _restore(saved: Dict[str, Any]) -> None:
            key_holder.update(saved)

        def _current_key() -> Any:
            return key_holder["key"]

        _BACKENDS["jax"] = Backend(
            name="jax",
            seed_fn=_make_key,
            get_state_fn=_snapshot,
            set_state_fn=_restore,
            extras={"get_key": _current_key},
        )
    except ImportError:
        logger.debug("jax not found -- skipping")


def _register_cupy_backend() -> None:
    """Register CuPy's global RNG, if cupy is importable."""
    try:
        import cupy as cp

        def _seed_global(seed_value: int) -> None:
            cp.random.seed(seed_value)

        def _restore(saved: Any) -> None:
            cp.random.set_random_state(saved)

        # NOTE(review): confirm that get_random_state() returns an independent
        # snapshot rather than the live generator object; if it is the live
        # object, handing it back to set_random_state() cannot rewind the
        # stream.
        _BACKENDS["cupy"] = Backend(
            name="cupy",
            seed_fn=_seed_global,
            get_state_fn=cp.random.get_random_state,
            set_state_fn=_restore,
        )
    except ImportError:
        logger.debug("cupy not found -- skipping")
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
# Run registration on import
|
|
165
|
+
_register_builtin_backends()
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
# ---------------------------------------------------------------------------
|
|
169
|
+
# Public API
|
|
170
|
+
# ---------------------------------------------------------------------------
|
|
171
|
+
|
|
172
|
+
def seed(
    value: int,
    *,
    backends: Optional[List[str]] = None,
    deterministic: bool = False,
    warn_missing: bool = False,
) -> Dict[str, bool]:
    """
    Seed all (or selected) RNG backends for reproducibility.

    Parameters
    ----------
    value : int
        The seed value (must be a non-negative integer).
    backends : list[str], optional
        Subset of backend names to seed. ``None`` means all registered
        backends; an explicit empty list seeds nothing.
    deterministic : bool
        If True, also enable PyTorch deterministic mode and disable
        cudnn benchmarking for maximum reproducibility (at a speed cost).
    warn_missing : bool
        If True, emit a warning for each requested backend that is not
        registered.

    Returns
    -------
    dict[str, bool]
        Mapping of backend name -> whether it was successfully seeded.
        Unregistered names and backends whose seed function raised map
        to ``False``.

    Raises
    ------
    TypeError
        If *value* is not an integer.
    ValueError
        If *value* is negative.
    """
    if not isinstance(value, int):
        raise TypeError(f"seed value must be an int, got {type(value).__name__}")
    if value < 0:
        raise ValueError(f"seed value must be non-negative, got {value}")

    # Only ``None`` selects every backend. Testing truthiness here would
    # turn an explicitly empty selection into "seed everything".
    targets = list(_BACKENDS) if backends is None else backends
    results: Dict[str, bool] = {}

    with _lock:
        for name in targets:
            backend = _BACKENDS.get(name)
            if backend is None:
                if warn_missing:
                    # stacklevel=2 attributes the warning to the caller.
                    warnings.warn(
                        f"seedall: backend '{name}' is not available",
                        stacklevel=2,
                    )
                results[name] = False
                continue

            try:
                backend.seed_fn(value)
            except Exception as exc:
                # Best-effort: one failing backend must not stop the rest.
                logger.warning("Failed to seed %s: %s", name, exc)
                results[name] = False
            else:
                results[name] = True
                logger.info("Seeded %s with %d", name, value)

        # PyTorch deterministic extras
        if deterministic:
            _set_torch_deterministic(True)

    return results
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
def available() -> List[str]:
    """List the names of the registered RNG backends, in registration order."""
    return [*_BACKENDS]
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
def get_states(backends: Optional[List[str]] = None) -> Dict[str, Any]:
    """Snapshot the current RNG state for each backend.

    Parameters
    ----------
    backends : list[str], optional
        Backend names to snapshot. ``None`` means all registered backends;
        an explicit empty list yields an empty snapshot. Names that are not
        registered are skipped.

    Returns
    -------
    dict[str, Any]
        Mapping of backend name -> whatever that backend's ``get_state_fn``
        returned. Suitable for passing to ``set_states()``.
    """
    # Only ``None`` selects every backend. Testing truthiness here would
    # turn an explicitly empty selection into "snapshot everything".
    targets = list(_BACKENDS) if backends is None else backends
    with _lock:
        return {
            name: _BACKENDS[name].get_state_fn()
            for name in targets
            if name in _BACKENDS
        }
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
def set_states(states: Dict[str, Any]) -> None:
    """Apply a snapshot produced by ``get_states()``.

    Entries whose name is not a registered backend are ignored.
    """
    with _lock:
        for name, snapshot in states.items():
            backend = _BACKENDS.get(name)
            if backend is None:
                continue
            backend.set_state_fn(snapshot)
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
@contextmanager
def temp_seed(
    value: int, *, deterministic: bool = False
) -> Generator[None, None, None]:
    """
    Seed every backend for the duration of a ``with`` block.

    The RNG states captured on entry are put back on exit, including when
    the body raises. If *deterministic* is True, the PyTorch deterministic
    setting reported on entry is also re-applied on exit.

    Example
    -------
    >>> with seedall.temp_seed(0):
    ...     x = np.random.rand()  # reproducible
    >>> y = np.random.rand()  # back to original sequence
    """
    snapshot = get_states()
    was_deterministic = _get_torch_deterministic()
    seed(value, deterministic=deterministic)
    try:
        yield
    finally:
        set_states(snapshot)
        if deterministic:
            _set_torch_deterministic(was_deterministic)
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
class SeedContext:
    """
    Reusable seeding context -- call ``.enter()`` / ``.exit()`` manually
    when a context manager isn't convenient (e.g. in test setUp/tearDown),
    or use as a regular ``with`` statement.

    ``enter()`` snapshots the current RNG states and seeds with *value*;
    ``exit()`` restores the snapshot. The same instance may be entered
    again after it has been exited.
    """

    def __init__(self, value: int, *, deterministic: bool = False) -> None:
        self.value = value
        self.deterministic = deterministic
        self._saved_states: Optional[Dict[str, Any]] = None
        self._saved_det: Optional[bool] = None
        self._active = False

    def enter(self) -> None:
        """Snapshot the RNG states, then seed with ``self.value``.

        Raises
        ------
        RuntimeError
            If the context is already active.
        TypeError, ValueError
            Propagated from ``seed()`` for an invalid ``self.value``; the
            context is left inactive in that case.
        """
        if self._active:
            raise RuntimeError("SeedContext.enter() called while already active")
        snapshot = get_states()
        was_deterministic = _get_torch_deterministic()
        # seed() rejects a bad value before touching any RNG. The context is
        # marked active only afterwards so that a rejected value does not
        # leave it stuck "active" with nothing to restore.
        seed(self.value, deterministic=self.deterministic)
        self._saved_states = snapshot
        self._saved_det = was_deterministic
        self._active = True

    def exit(self) -> None:
        """Restore the states captured by the matching ``enter()``.

        Raises
        ------
        RuntimeError
            If the context is not active.
        """
        if not self._active:
            raise RuntimeError("SeedContext.exit() called without a matching enter()")
        try:
            if self._saved_states is not None:
                set_states(self._saved_states)
            if self.deterministic and self._saved_det is not None:
                _set_torch_deterministic(self._saved_det)
        finally:
            # Release the context even if restoring raised, so the instance
            # can be entered again.
            self._active = False

    def __enter__(self) -> "SeedContext":
        self.enter()
        return self

    def __exit__(self, *exc: Any) -> None:
        self.exit()
|
|
325
|
+
|
|
326
|
+
|
|
327
|
+
# ---------------------------------------------------------------------------
|
|
328
|
+
# Helpers
|
|
329
|
+
# ---------------------------------------------------------------------------
|
|
330
|
+
|
|
331
|
+
# cuDNN flags as they were just before deterministic mode was last switched
# on by _set_torch_deterministic(); empty when there is nothing to put back.
_prev_cudnn_flags: Dict[str, bool] = {}


def _set_torch_deterministic(enabled: bool) -> None:
    """Switch PyTorch deterministic mode on or off (no-op without torch).

    Enabling turns on deterministic algorithms, sets
    ``cudnn.deterministic = True`` and ``cudnn.benchmark = False``, and
    remembers the cuDNN flags that were in effect beforehand. Disabling
    turns deterministic algorithms off and puts the remembered cuDNN flags
    back, so an enable/disable pair leaves them as it found them. With
    nothing remembered, disabling falls back to ``cudnn.deterministic =
    False`` and ``cudnn.benchmark = True``.

    Any error raised by torch while applying the settings is logged at
    debug level and otherwise ignored (best-effort).
    """
    try:
        import torch
    except ImportError:
        return

    try:
        cudnn = torch.backends.cudnn
        if enabled:
            # Remember the flags only on an off -> on transition; when
            # deterministic mode is already on, the current flags are not
            # the ones that should be restored later.
            if not torch.are_deterministic_algorithms_enabled():
                _prev_cudnn_flags["deterministic"] = cudnn.deterministic
                _prev_cudnn_flags["benchmark"] = cudnn.benchmark
            torch.use_deterministic_algorithms(True)
            cudnn.deterministic = True
            cudnn.benchmark = False
        else:
            torch.use_deterministic_algorithms(False)
            cudnn.deterministic = _prev_cudnn_flags.pop("deterministic", False)
            cudnn.benchmark = _prev_cudnn_flags.pop("benchmark", True)
    except Exception as exc:
        logger.debug("Could not set torch deterministic mode: %s", exc)


def _get_torch_deterministic() -> bool:
    """Report whether torch deterministic algorithms are enabled.

    Returns ``False`` when torch is not installed or does not provide
    ``are_deterministic_algorithms_enabled``.
    """
    try:
        import torch
        return torch.are_deterministic_algorithms_enabled()
    except (ImportError, AttributeError):
        return False
|
|
File without changes
|