nnx-ppo 0.2.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nnx_ppo-0.2.1/CHANGELOG.md +55 -0
- nnx_ppo-0.2.1/LICENSE +28 -0
- nnx_ppo-0.2.1/MANIFEST.in +12 -0
- nnx_ppo-0.2.1/PKG-INFO +136 -0
- nnx_ppo-0.2.1/README.md +91 -0
- nnx_ppo-0.2.1/nnx_ppo/__init__.py +1 -0
- nnx_ppo-0.2.1/nnx_ppo/algorithms/__init__.py +0 -0
- nnx_ppo-0.2.1/nnx_ppo/algorithms/callbacks.py +35 -0
- nnx_ppo-0.2.1/nnx_ppo/algorithms/checkpointing.py +204 -0
- nnx_ppo-0.2.1/nnx_ppo/algorithms/checkpointing_test.py +502 -0
- nnx_ppo-0.2.1/nnx_ppo/algorithms/config.py +127 -0
- nnx_ppo-0.2.1/nnx_ppo/algorithms/distillation.py +603 -0
- nnx_ppo-0.2.1/nnx_ppo/algorithms/distillation_test.py +201 -0
- nnx_ppo-0.2.1/nnx_ppo/algorithms/metrics.py +121 -0
- nnx_ppo-0.2.1/nnx_ppo/algorithms/ppo.py +572 -0
- nnx_ppo-0.2.1/nnx_ppo/algorithms/ppo_test.py +486 -0
- nnx_ppo-0.2.1/nnx_ppo/algorithms/rollout.py +279 -0
- nnx_ppo-0.2.1/nnx_ppo/algorithms/rollout_test.py +222 -0
- nnx_ppo-0.2.1/nnx_ppo/algorithms/types.py +150 -0
- nnx_ppo-0.2.1/nnx_ppo/conftest.py +10 -0
- nnx_ppo-0.2.1/nnx_ppo/jax_dataclass.py +41 -0
- nnx_ppo-0.2.1/nnx_ppo/networks/__init__.py +0 -0
- nnx_ppo-0.2.1/nnx_ppo/networks/adapter.py +133 -0
- nnx_ppo-0.2.1/nnx_ppo/networks/adapter_test.py +112 -0
- nnx_ppo-0.2.1/nnx_ppo/networks/ar1_rollout_test.py +288 -0
- nnx_ppo-0.2.1/nnx_ppo/networks/containers.py +218 -0
- nnx_ppo-0.2.1/nnx_ppo/networks/containers_test.py +127 -0
- nnx_ppo-0.2.1/nnx_ppo/networks/delay.py +95 -0
- nnx_ppo-0.2.1/nnx_ppo/networks/delay_test.py +174 -0
- nnx_ppo-0.2.1/nnx_ppo/networks/factories.py +146 -0
- nnx_ppo-0.2.1/nnx_ppo/networks/factories_test.py +123 -0
- nnx_ppo-0.2.1/nnx_ppo/networks/feedforward.py +51 -0
- nnx_ppo-0.2.1/nnx_ppo/networks/graph/__init__.py +12 -0
- nnx_ppo-0.2.1/nnx_ppo/networks/graph/connection.py +34 -0
- nnx_ppo-0.2.1/nnx_ppo/networks/graph/graph.py +448 -0
- nnx_ppo-0.2.1/nnx_ppo/networks/graph/graph_test.py +271 -0
- nnx_ppo-0.2.1/nnx_ppo/networks/graph/population.py +38 -0
- nnx_ppo-0.2.1/nnx_ppo/networks/normalizer.py +136 -0
- nnx_ppo-0.2.1/nnx_ppo/networks/normalizer_test.py +181 -0
- nnx_ppo-0.2.1/nnx_ppo/networks/recurrent.py +161 -0
- nnx_ppo-0.2.1/nnx_ppo/networks/recurrent_test.py +347 -0
- nnx_ppo-0.2.1/nnx_ppo/networks/sampling_layers.py +147 -0
- nnx_ppo-0.2.1/nnx_ppo/networks/types.py +113 -0
- nnx_ppo-0.2.1/nnx_ppo/networks/utils.py +326 -0
- nnx_ppo-0.2.1/nnx_ppo/networks/utils_test.py +316 -0
- nnx_ppo-0.2.1/nnx_ppo/networks/variational.py +216 -0
- nnx_ppo-0.2.1/nnx_ppo/networks/variational_test.py +178 -0
- nnx_ppo-0.2.1/nnx_ppo/py.typed +0 -0
- nnx_ppo-0.2.1/nnx_ppo/test_dummies/__init__.py +0 -0
- nnx_ppo-0.2.1/nnx_ppo/test_dummies/dict_obs_act_env.py +209 -0
- nnx_ppo-0.2.1/nnx_ppo/test_dummies/dummy_counter.py +75 -0
- nnx_ppo-0.2.1/nnx_ppo/test_dummies/mock_env.py +63 -0
- nnx_ppo-0.2.1/nnx_ppo/test_dummies/move_from_center_env.py +51 -0
- nnx_ppo-0.2.1/nnx_ppo/test_dummies/move_to_center_env.py +50 -0
- nnx_ppo-0.2.1/nnx_ppo/test_dummies/parrot_env.py +43 -0
- nnx_ppo-0.2.1/nnx_ppo/test_dummies/stateful_nets.py +40 -0
- nnx_ppo-0.2.1/nnx_ppo/wrappers/__init__.py +0 -0
- nnx_ppo-0.2.1/nnx_ppo/wrappers/episode_wrapper.py +40 -0
- nnx_ppo-0.2.1/nnx_ppo/wrappers/episode_wrapper_test.py +57 -0
- nnx_ppo-0.2.1/nnx_ppo/wrappers/reward_scaling_wrapper.py +28 -0
- nnx_ppo-0.2.1/nnx_ppo.egg-info/SOURCES.txt +60 -0
- nnx_ppo-0.2.1/pyproject.toml +62 -0
- nnx_ppo-0.2.1/setup.cfg +4 -0
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
# Changelog
|
|
2
|
+
|
|
3
|
+
All notable changes to `nnx-ppo` are recorded here. The format follows
|
|
4
|
+
[Keep a Changelog](https://keepachangelog.com/en/1.1.0/) and the project
|
|
5
|
+
adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
|
6
|
+
|
|
7
|
+
## [0.2.1] — 2026-06-09
|
|
8
|
+
|
|
9
|
+
### Added
|
|
10
|
+
- `LoggingLevel.THROUGHPUT` — emits `throughput/train_sps`,
|
|
11
|
+
`throughput/eval_sps`, and `throughput/video_sps` (env + render),
|
|
12
|
+
with `jax.block_until_ready` barriers so the numbers reflect
|
|
13
|
+
device-side wall-clock rather than JAX dispatch latency. Included in
|
|
14
|
+
`LoggingLevel.ALL`.
|
|
15
|
+
- `losses/clipping_fraction` under `LoggingLevel.ACTOR_EXTRA` — the
|
|
16
|
+
fraction of samples whose likelihood ratio left the PPO clip range
|
|
17
|
+
during the gradient phase. Tree-mapped, so it works with multi-actor
|
|
18
|
+
/ multi-agent log-likelihoods.
|
|
19
|
+
|
|
20
|
+
### Changed
|
|
21
|
+
- `RewardScalingWrapper` no longer depends on `mujoco_playground`; it
|
|
22
|
+
is now typed against the local `RLEnv` / `EnvState` protocols in
|
|
23
|
+
`nnx_ppo.algorithms.types`.
|
|
24
|
+
- The `[playground]` extra has been removed. `playground` and
|
|
25
|
+
`warp-lang` are now part of the `[dev]` and `[examples]` extras.
|
|
26
|
+
- Minimum `flax` is now `0.12.7`.
|
|
27
|
+
- License metadata switched to PEP 639 SPDX form
|
|
28
|
+
(`license = "BSD-3-Clause"` + `license-files = ["LICENSE"]`);
|
|
29
|
+
requires `setuptools>=77.0` at build time.
|
|
30
|
+
|
|
31
|
+
### Fixed
|
|
32
|
+
- `PopulationGraph` no longer exposes its build-time registries as a
|
|
33
|
+
second set of `nnx.Param`s — newer Flax versions reflected through
|
|
34
|
+
the underscore-prefixed dicts, which tripped `nnx.jit`'s
|
|
35
|
+
consistent-aliasing check.
|
|
36
|
+
|
|
37
|
+
### Removed
|
|
38
|
+
- `correlations/action_ll` (under `ACTOR_EXTRA`) — was only emitted
|
|
39
|
+
for 1-D action spaces and never fired in multi-actuator setups.
|
|
40
|
+
|
|
41
|
+
## [0.2.0] — 2026-06-03
|
|
42
|
+
|
|
43
|
+
Initial PyPI release.
|
|
44
|
+
|
|
45
|
+
### Added
|
|
46
|
+
- Stateful-network PPO training loop (`nnx_ppo.algorithms.ppo.train_ppo`).
|
|
47
|
+
- Network containers (`Sequential`, `Parallel`, `Concat`, `Splitter`) and the
|
|
48
|
+
two-port `PPOAdapter`.
|
|
49
|
+
- Built-in layers: `Dense`, `LSTM`, `AR1VariationalBottleneck`, `Normalizer`,
|
|
50
|
+
`Delay`, sampling layers, and graph-population utilities.
|
|
51
|
+
- Rollout machinery with per-environment state reset and `update_statistics`
|
|
52
|
+
hook for stats-bearing modules.
|
|
53
|
+
- Orbax-based checkpointing.
|
|
54
|
+
- Distillation utility (`nnx_ppo.algorithms.distillation`).
|
|
55
|
+
- Documentation site at <https://nnx-ppo.readthedocs.io>.
|
nnx_ppo-0.2.1/LICENSE
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
BSD 3-Clause License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026, Emil Wärnberg
|
|
4
|
+
|
|
5
|
+
Redistribution and use in source and binary forms, with or without
|
|
6
|
+
modification, are permitted provided that the following conditions are met:
|
|
7
|
+
|
|
8
|
+
1. Redistributions of source code must retain the above copyright notice, this
|
|
9
|
+
list of conditions and the following disclaimer.
|
|
10
|
+
|
|
11
|
+
2. Redistributions in binary form must reproduce the above copyright notice,
|
|
12
|
+
this list of conditions and the following disclaimer in the documentation
|
|
13
|
+
and/or other materials provided with the distribution.
|
|
14
|
+
|
|
15
|
+
3. Neither the name of the copyright holder nor the names of its
|
|
16
|
+
contributors may be used to endorse or promote products derived from
|
|
17
|
+
this software without specific prior written permission.
|
|
18
|
+
|
|
19
|
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
20
|
+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
21
|
+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
22
|
+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
23
|
+
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
24
|
+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
25
|
+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
26
|
+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
27
|
+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
28
|
+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
include README.md
|
|
2
|
+
include LICENSE
|
|
3
|
+
include CHANGELOG.md
|
|
4
|
+
include pyproject.toml
|
|
5
|
+
include nnx_ppo/py.typed
|
|
6
|
+
recursive-exclude docs/_build *
|
|
7
|
+
recursive-exclude * __pycache__
|
|
8
|
+
recursive-exclude * *.py[cod]
|
|
9
|
+
prune .pytest_cache
|
|
10
|
+
prune wandb
|
|
11
|
+
prune .venv
|
|
12
|
+
prune nnx_ppo.egg-info
|
nnx_ppo-0.2.1/PKG-INFO
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: nnx-ppo
|
|
3
|
+
Version: 0.2.1
|
|
4
|
+
Summary: PPO for stateful/recurrent networks in JAX and flax.nnx.
|
|
5
|
+
Author-email: Emil Wärnberg <ewarnberg@fas.harvard.edu>
|
|
6
|
+
License-Expression: BSD-3-Clause
|
|
7
|
+
Project-URL: Homepage, https://github.com/emiwar/nnx-ppo
|
|
8
|
+
Project-URL: Documentation, https://nnx-ppo.readthedocs.io
|
|
9
|
+
Project-URL: Repository, https://github.com/emiwar/nnx-ppo
|
|
10
|
+
Project-URL: Issues, https://github.com/emiwar/nnx-ppo/issues
|
|
11
|
+
Keywords: reinforcement-learning,ppo,jax,flax,rl,neuroscience
|
|
12
|
+
Classifier: Development Status :: 3 - Alpha
|
|
13
|
+
Classifier: Intended Audience :: Science/Research
|
|
14
|
+
Classifier: Operating System :: OS Independent
|
|
15
|
+
Classifier: Programming Language :: Python :: 3
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
18
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
19
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
20
|
+
Requires-Python: >=3.11
|
|
21
|
+
Description-Content-Type: text/markdown
|
|
22
|
+
License-File: LICENSE
|
|
23
|
+
Requires-Dist: jax
|
|
24
|
+
Requires-Dist: flax>=0.12.7
|
|
25
|
+
Requires-Dist: optax
|
|
26
|
+
Requires-Dist: orbax-checkpoint
|
|
27
|
+
Requires-Dist: jaxtyping
|
|
28
|
+
Requires-Dist: numpy
|
|
29
|
+
Provides-Extra: dev
|
|
30
|
+
Requires-Dist: pytest; extra == "dev"
|
|
31
|
+
Requires-Dist: pyright; extra == "dev"
|
|
32
|
+
Requires-Dist: beartype; extra == "dev"
|
|
33
|
+
Requires-Dist: absl-py; extra == "dev"
|
|
34
|
+
Requires-Dist: playground>=0.2.0; extra == "dev"
|
|
35
|
+
Requires-Dist: warp-lang<1.13; extra == "dev"
|
|
36
|
+
Provides-Extra: examples
|
|
37
|
+
Requires-Dist: brax; extra == "examples"
|
|
38
|
+
Requires-Dist: wandb; extra == "examples"
|
|
39
|
+
Requires-Dist: playground>=0.2.0; extra == "examples"
|
|
40
|
+
Requires-Dist: warp-lang<1.13; extra == "examples"
|
|
41
|
+
Provides-Extra: docs
|
|
42
|
+
Requires-Dist: sphinx>=7.0; extra == "docs"
|
|
43
|
+
Requires-Dist: sphinx-rtd-theme; extra == "docs"
|
|
44
|
+
Dynamic: license-file
|
|
45
|
+
|
|
46
|
+
# nnx-ppo
|
|
47
|
+
|
|
48
|
+
Experimental implementation of [Proximal Policy Optimization](https://en.wikipedia.org/wiki/Proximal_policy_optimization)
|
|
49
|
+
in [JAX](https://github.com/google/jax), with first-class support for
|
|
50
|
+
recurrent/stateful networks. Networks are built with
|
|
51
|
+
[flax.nnx](https://flax.readthedocs.io); environments follow the
|
|
52
|
+
[MuJoCo Playground](https://playground.mujoco.org/) API.
|
|
53
|
+
|
|
54
|
+
> **Status:** experimental — the API may change without notice.
|
|
55
|
+
|
|
56
|
+
## Highlights
|
|
57
|
+
|
|
58
|
+
- **Stateful modules.** Recurrent layers, delayed connections, and noisy /
|
|
59
|
+
variational populations are all first-class citizens. Carry state is
|
|
60
|
+
threaded through rollout collection *and* multi-epoch loss replay, and is
|
|
61
|
+
reset correctly when the environment resets — something other JAX RL
|
|
62
|
+
libraries (e.g. Brax) do not natively support.
|
|
63
|
+
- **PyTree observations.** Observations can be arbitrary nested dicts, which
|
|
64
|
+
makes it easy to route different streams (proprioception, vision,
|
|
65
|
+
imitation targets, …) to different parts of a network.
|
|
66
|
+
- **PyTree actions and rewards.** Actions and rewards are also allowed to be
|
|
67
|
+
PyTrees, which simplifies multi-actuator and multi-agent setups.
|
|
68
|
+
|
|
69
|
+
## Installation
|
|
70
|
+
|
|
71
|
+
```bash
|
|
72
|
+
pip install nnx_ppo
|
|
73
|
+
```
|
|
74
|
+
|
|
75
|
+
`nnx-ppo` installs the CPU build of JAX by default. For a CUDA 12 GPU build:
|
|
76
|
+
|
|
77
|
+
```bash
|
|
78
|
+
pip install nnx_ppo "jax[cuda12]"
|
|
79
|
+
```
|
|
80
|
+
|
|
81
|
+
Optional extras:
|
|
82
|
+
|
|
83
|
+
- `nnx_ppo[examples]` — `brax`, `wandb`, and
|
|
84
|
+
[`playground`](https://pypi.org/project/playground/) (import name
|
|
85
|
+
`mujoco_playground`) for the scripts in [examples/](examples/).
|
|
86
|
+
- `nnx_ppo[dev]` — test-suite dependencies (`pytest`, `pyright`, `beartype`,
|
|
87
|
+
`absl-py`, plus `playground` for the env-driven tests).
|
|
88
|
+
|
|
89
|
+
## Quick example
|
|
90
|
+
|
|
91
|
+
```python
|
|
92
|
+
from flax import nnx
|
|
93
|
+
import mujoco_playground
|
|
94
|
+
|
|
95
|
+
from nnx_ppo.algorithms import ppo
|
|
96
|
+
from nnx_ppo.algorithms.config import TrainConfig, PPOConfig, EvalConfig
|
|
97
|
+
from nnx_ppo.networks.factories import make_mlp_actor_critic
|
|
98
|
+
from nnx_ppo.wrappers import episode_wrapper
|
|
99
|
+
|
|
100
|
+
env = mujoco_playground.registry.load("CartpoleBalance")
|
|
101
|
+
train_env = episode_wrapper.EpisodeWrapper(env, 1000)
|
|
102
|
+
|
|
103
|
+
rngs = nnx.Rngs(0)
|
|
104
|
+
nets = make_mlp_actor_critic(
|
|
105
|
+
env.observation_size,
|
|
106
|
+
env.action_size,
|
|
107
|
+
actor_hidden_sizes=[64] * 4,
|
|
108
|
+
critic_hidden_sizes=[256] * 2,
|
|
109
|
+
rngs=rngs,
|
|
110
|
+
normalize_obs=True,
|
|
111
|
+
)
|
|
112
|
+
|
|
113
|
+
result = ppo.train_ppo(
|
|
114
|
+
train_env,
|
|
115
|
+
nets,
|
|
116
|
+
TrainConfig(
|
|
117
|
+
ppo=PPOConfig(n_envs=1024, rollout_length=30, total_steps=10_000_000),
|
|
118
|
+
eval=EvalConfig(enabled=True, every_steps=500_000, n_envs=64,
|
|
119
|
+
max_episode_length=1000),
|
|
120
|
+
),
|
|
121
|
+
eval_env=env,
|
|
122
|
+
)
|
|
123
|
+
print(f"Final eval reward: {result.eval_history[-1]['episode_reward_mean']}")
|
|
124
|
+
```
|
|
125
|
+
|
|
126
|
+
See [examples/wandb_logging.py](examples/wandb_logging.py) for a complete
|
|
127
|
+
training script with W&B logging and video recording.
|
|
128
|
+
|
|
129
|
+
## Documentation
|
|
130
|
+
|
|
131
|
+
Full documentation, tutorials, and API reference are at
|
|
132
|
+
<https://nnx-ppo.readthedocs.io>.
|
|
133
|
+
|
|
134
|
+
## License
|
|
135
|
+
|
|
136
|
+
BSD 3-Clause — see [LICENSE](LICENSE).
|
nnx_ppo-0.2.1/README.md
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
# nnx-ppo
|
|
2
|
+
|
|
3
|
+
Experimental implementation of [Proximal Policy Optimization](https://en.wikipedia.org/wiki/Proximal_policy_optimization)
|
|
4
|
+
in [JAX](https://github.com/google/jax), with first-class support for
|
|
5
|
+
recurrent/stateful networks. Networks are built with
|
|
6
|
+
[flax.nnx](https://flax.readthedocs.io); environments follow the
|
|
7
|
+
[MuJoCo Playground](https://playground.mujoco.org/) API.
|
|
8
|
+
|
|
9
|
+
> **Status:** experimental — the API may change without notice.
|
|
10
|
+
|
|
11
|
+
## Highlights
|
|
12
|
+
|
|
13
|
+
- **Stateful modules.** Recurrent layers, delayed connections, and noisy /
|
|
14
|
+
variational populations are all first-class citizens. Carry state is
|
|
15
|
+
threaded through rollout collection *and* multi-epoch loss replay, and is
|
|
16
|
+
reset correctly when the environment resets — something other JAX RL
|
|
17
|
+
libraries (e.g. Brax) do not natively support.
|
|
18
|
+
- **PyTree observations.** Observations can be arbitrary nested dicts, which
|
|
19
|
+
makes it easy to route different streams (proprioception, vision,
|
|
20
|
+
imitation targets, …) to different parts of a network.
|
|
21
|
+
- **PyTree actions and rewards.** Actions and rewards are also allowed to be
|
|
22
|
+
PyTrees, which simplifies multi-actuator and multi-agent setups.
|
|
23
|
+
|
|
24
|
+
## Installation
|
|
25
|
+
|
|
26
|
+
```bash
|
|
27
|
+
pip install nnx_ppo
|
|
28
|
+
```
|
|
29
|
+
|
|
30
|
+
`nnx-ppo` installs the CPU build of JAX by default. For a CUDA 12 GPU build:
|
|
31
|
+
|
|
32
|
+
```bash
|
|
33
|
+
pip install nnx_ppo "jax[cuda12]"
|
|
34
|
+
```
|
|
35
|
+
|
|
36
|
+
Optional extras:
|
|
37
|
+
|
|
38
|
+
- `nnx_ppo[examples]` — `brax`, `wandb`, and
|
|
39
|
+
[`playground`](https://pypi.org/project/playground/) (import name
|
|
40
|
+
`mujoco_playground`) for the scripts in [examples/](examples/).
|
|
41
|
+
- `nnx_ppo[dev]` — test-suite dependencies (`pytest`, `pyright`, `beartype`,
|
|
42
|
+
`absl-py`, plus `playground` for the env-driven tests).
|
|
43
|
+
|
|
44
|
+
## Quick example
|
|
45
|
+
|
|
46
|
+
```python
|
|
47
|
+
from flax import nnx
|
|
48
|
+
import mujoco_playground
|
|
49
|
+
|
|
50
|
+
from nnx_ppo.algorithms import ppo
|
|
51
|
+
from nnx_ppo.algorithms.config import TrainConfig, PPOConfig, EvalConfig
|
|
52
|
+
from nnx_ppo.networks.factories import make_mlp_actor_critic
|
|
53
|
+
from nnx_ppo.wrappers import episode_wrapper
|
|
54
|
+
|
|
55
|
+
env = mujoco_playground.registry.load("CartpoleBalance")
|
|
56
|
+
train_env = episode_wrapper.EpisodeWrapper(env, 1000)
|
|
57
|
+
|
|
58
|
+
rngs = nnx.Rngs(0)
|
|
59
|
+
nets = make_mlp_actor_critic(
|
|
60
|
+
env.observation_size,
|
|
61
|
+
env.action_size,
|
|
62
|
+
actor_hidden_sizes=[64] * 4,
|
|
63
|
+
critic_hidden_sizes=[256] * 2,
|
|
64
|
+
rngs=rngs,
|
|
65
|
+
normalize_obs=True,
|
|
66
|
+
)
|
|
67
|
+
|
|
68
|
+
result = ppo.train_ppo(
|
|
69
|
+
train_env,
|
|
70
|
+
nets,
|
|
71
|
+
TrainConfig(
|
|
72
|
+
ppo=PPOConfig(n_envs=1024, rollout_length=30, total_steps=10_000_000),
|
|
73
|
+
eval=EvalConfig(enabled=True, every_steps=500_000, n_envs=64,
|
|
74
|
+
max_episode_length=1000),
|
|
75
|
+
),
|
|
76
|
+
eval_env=env,
|
|
77
|
+
)
|
|
78
|
+
print(f"Final eval reward: {result.eval_history[-1]['episode_reward_mean']}")
|
|
79
|
+
```
|
|
80
|
+
|
|
81
|
+
See [examples/wandb_logging.py](examples/wandb_logging.py) for a complete
|
|
82
|
+
training script with W&B logging and video recording.
|
|
83
|
+
|
|
84
|
+
## Documentation
|
|
85
|
+
|
|
86
|
+
Full documentation, tutorials, and API reference are at
|
|
87
|
+
<https://nnx-ppo.readthedocs.io>.
|
|
88
|
+
|
|
89
|
+
## License
|
|
90
|
+
|
|
91
|
+
BSD 3-Clause — see [LICENSE](LICENSE).
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
__version__ = "0.2.1"
|
|
File without changes
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
"""Callback helpers for training logging."""
|
|
2
|
+
|
|
3
|
+
from collections.abc import Callable
|
|
4
|
+
|
|
5
|
+
from nnx_ppo.algorithms.config import VideoData
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def wandb_video_fn(
    key: str = "eval_video", fps: int = 30
) -> Callable[[VideoData], None]:
    """Build a ``video_fn`` callback that uploads rendered episodes to wandb.

    ``wandb`` is imported when this factory is called rather than at module
    import time, so importing this module does not require wandb.

    Args:
        key: Name under which the video is logged in wandb.
        fps: Playback frame rate passed to ``wandb.Video``.

    Returns:
        A callable that takes a :class:`VideoData`, converts its frames from
        THWC to TCHW layout, and logs them as an mp4 at ``data.step``.
        Suitable for train_ppo's ``video_fn`` parameter.

    Example:
        >>> result = train_ppo(
        ...     env, networks,
        ...     video_fn=wandb_video_fn(fps=50),
        ... )
    """
    import wandb

    def video_fn(data: VideoData) -> None:
        # wandb.Video expects channel-first frames, so reorder THWC -> TCHW.
        tchw_frames = data.frames.transpose(0, 3, 1, 2)
        clip = wandb.Video(tchw_frames, fps=fps, format="mp4")
        wandb.log({key: clip}, step=data.step)

    return video_fn
|
|
@@ -0,0 +1,204 @@
|
|
|
1
|
+
"""Checkpointing utilities for saving and loading training state."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import pickle
|
|
5
|
+
from collections.abc import Callable
|
|
6
|
+
from typing import Any, Optional, Protocol, runtime_checkable
|
|
7
|
+
|
|
8
|
+
import jax
|
|
9
|
+
from flax import nnx
|
|
10
|
+
|
|
11
|
+
from nnx_ppo.algorithms.config import TrainConfig
|
|
12
|
+
from nnx_ppo.algorithms.types import TrainingState
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@runtime_checkable
class CheckpointCallback(Protocol):
    """Structural type for checkpoint hooks called with named arguments.

    Any callable accepting ``(training_state, step)`` satisfies this protocol,
    including the closure returned by :func:`make_checkpoint_fn`. The
    parameters are named so that callers may pass them as keywords.
    """

    def __call__(self, training_state: TrainingState, step: int) -> None:
        """Handle a checkpoint event for ``training_state`` at ``step``."""
        ...
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _split_net_state(networks):
    """Partition network state into an orbax-saved part and a pickled part.

    orbax cannot serialize JAX new-style PRNG key arrays (dtype
    ``key<fry>``). Variables of type ``nnx.RngKey`` are therefore split off to
    be persisted with pickle. Every other variable type goes to orbax,
    including ``nnx.Param``, ``nnx.RngCount`` and custom Variable subclasses
    such as NormalizerStatistics.

    Args:
        networks: The nnx module graph whose state should be split.

    Returns:
        A tuple ``(non_key_state, rng_key_state, abstract_non_key)``.
        The first two are ``nnx.State`` objects. The third mirrors
        ``non_key_state`` with every array leaf replaced by a
        ``jax.ShapeDtypeStruct``, which is the target orbax needs for
        restoration.
    """
    _graphdef, key_vars, other_vars = nnx.split(networks, nnx.RngKey, ...)

    def _as_abstract(leaf):
        return jax.ShapeDtypeStruct(leaf.shape, leaf.dtype)

    abstract_other = jax.tree_util.tree_map(_as_abstract, other_vars)
    return other_vars, key_vars, abstract_other
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def make_checkpoint_fn(
    directory: str,
    config: Optional[TrainConfig] = None,
) -> CheckpointCallback:
    """Create a checkpoint callback that saves TrainingState to disk.

    Each checkpoint is written to ``{directory}/step_{step:010d}/``, containing:

    - ``networks/`` — orbax checkpoint with all non-PRNG-key network variables
      (Param, RngCount, NormalizerStatistics, etc.)
    - ``optimizer/`` — orbax checkpoint with all optimizer state arrays
    - ``metadata.pkl`` — pickle file with network RngKey variables,
      all remaining TrainingState fields (``network_states``, ``env_states``,
      ``rng_key``, ``steps_taken``), the step count, and the optional
      TrainConfig.

    ``metadata.pkl`` is written last, only after both orbax saves have
    returned and the checkpointer has been closed. It is first written to a
    temporary file and then renamed into place. Readers therefore never see a
    truncated metadata file, and a step directory without ``metadata.pkl``
    is incomplete.

    To resume training from a checkpoint, use :func:`load_checkpoint`.

    Args:
        directory: Base directory under which checkpoint subdirectories are
            created.
        config: Optional TrainConfig to store alongside each checkpoint, useful
            for reproducing training runs.

    Returns:
        A callback compatible with train_ppo's ``checkpoint_fn`` parameter.

    Example:
        >>> result = train_ppo(
        ...     env, networks, config,
        ...     checkpoint_fn=make_checkpoint_fn("/tmp/my_run", config=config),
        ... )
    """

    abs_directory = os.path.abspath(directory)

    def checkpoint_fn(training_state: TrainingState, step: int) -> None:
        import orbax.checkpoint as ocp

        step_dir = os.path.join(abs_directory, f"step_{step:010d}")
        os.makedirs(step_dir, exist_ok=True)

        # Split network state: everything except RngKey → orbax, RngKey → pickle.
        # orbax cannot handle JAX new-style PRNG key arrays.
        non_key_state, rng_key_state, _ = _split_net_state(training_state.networks)

        # The optimizer only contains float/int arrays; no key arrays.
        _, opt_state = nnx.split(training_state.optimizer)

        # Save parameter arrays with orbax. A fresh checkpointer is created per
        # call and immediately closed to ensure all async writes complete.
        checkpointer = ocp.StandardCheckpointer()
        try:
            checkpointer.save(os.path.join(step_dir, "networks"), non_key_state)
            checkpointer.save(os.path.join(step_dir, "optimizer"), opt_state)
        finally:
            checkpointer.close()

        # Save everything else with pickle (JAX arrays including PRNG keys are
        # pickle-safe).
        metadata = {
            "networks_rng_key_state": rng_key_state,
            "network_states": training_state.network_states,
            "env_states": training_state.env_states,
            "rng_key": training_state.rng_key,
            "steps_taken": training_state.steps_taken,
            "step": step,
            "config": config,
        }

        # Write to a sibling temp file, then atomically rename it over the
        # final name. A crash or pickling error mid-write then can't leave a
        # truncated metadata.pkl that load_checkpoint would choke on.
        metadata_path = os.path.join(step_dir, "metadata.pkl")
        tmp_path = metadata_path + ".tmp"
        try:
            with open(tmp_path, "wb") as f:
                pickle.dump(metadata, f)
                f.flush()
                os.fsync(f.fileno())
            os.replace(tmp_path, metadata_path)
        finally:
            # On success the temp file was already renamed away; on failure,
            # remove the partial file so it is not mistaken for real data.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    return checkpoint_fn
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def load_checkpoint(
    path: str,
    networks: Any,
    optimizer: nnx.Optimizer,
) -> dict[str, Any]:
    """Load a checkpoint saved by :func:`make_checkpoint_fn`.

    The ``networks`` and ``optimizer`` arguments serve as structural templates.
    Their architecture must match the checkpoint. Their current parameter
    values are irrelevant and are overwritten in-place with the checkpoint
    values.

    .. warning::
        ``metadata.pkl`` is read with :func:`pickle.load`, which can execute
        arbitrary code. Only load checkpoints from sources you trust.

    Args:
        path: Path to the step checkpoint directory, e.g.
            ``/tmp/my_run/step_0000500000``.
        networks: Network instance with the same architecture as the checkpoint.
            Weights are updated in-place.
        optimizer: Optimizer instance with the same structure as the checkpoint.
            State is updated in-place.

    Returns:
        A dict with the following keys:

        - ``"training_state"`` — restored :class:`TrainingState`
        - ``"step"`` — training step at which the checkpoint was saved (int)
        - ``"config"`` — :class:`TrainConfig` if one was stored, else ``None``

    Raises:
        FileNotFoundError: If ``path`` contains no ``metadata.pkl``. The path
            is then either not a checkpoint directory or its save did not
            complete. Nothing is modified in this case.

    Example:
        >>> networks = factories.make_mlp_actor_critic(...)
        >>> training_state = ppo.new_training_state(env, networks, n_envs, seed)
        >>> ckpt = load_checkpoint(
        ...     "/tmp/my_run/step_0000500000",
        ...     training_state.networks,
        ...     training_state.optimizer,
        ... )
        >>> result = train_ppo(
        ...     env, networks, ckpt["config"],
        ...     initial_state=ckpt["training_state"],
        ... )
    """
    import orbax.checkpoint as ocp

    path = os.path.abspath(path)
    metadata_path = os.path.join(path, "metadata.pkl")

    # make_checkpoint_fn writes metadata.pkl after the orbax saves, so a
    # missing file means an invalid or incomplete checkpoint. Fail fast with
    # a clear message before doing any (possibly slow) orbax restores.
    if not os.path.isfile(metadata_path):
        raise FileNotFoundError(
            f"No checkpoint metadata found at {metadata_path!r}; the path is "
            "not a checkpoint directory or the checkpoint save did not complete."
        )

    # Build abstract targets from the user-provided templates, using the same
    # split as saving so the two cannot drift apart. RngKey variables are
    # excluded here and restored from the pickle below.
    _, _, abstract_non_key = _split_net_state(networks)
    _, opt_template = nnx.split(optimizer)
    opt_abstract = jax.tree_util.tree_map(
        lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype), opt_template
    )

    checkpointer = ocp.StandardCheckpointer()
    try:
        restored_non_key = checkpointer.restore(
            os.path.join(path, "networks"), abstract_non_key
        )
        restored_opt = checkpointer.restore(
            os.path.join(path, "optimizer"), opt_abstract
        )
    finally:
        checkpointer.close()

    # SECURITY: pickle.load runs arbitrary code from the file; see docstring.
    with open(metadata_path, "rb") as f:
        metadata = pickle.load(f)

    # Merge orbax-restored non-key state with pickled rng-key state,
    # then update the provided modules in-place.
    full_net_state = nnx.merge_state(restored_non_key, metadata["networks_rng_key_state"])
    nnx.update(networks, full_net_state)
    nnx.update(optimizer, restored_opt)

    training_state = TrainingState(
        networks=networks,
        network_states=metadata["network_states"],
        env_states=metadata["env_states"],
        optimizer=optimizer,
        rng_key=metadata["rng_key"],
        steps_taken=metadata["steps_taken"],
    )
    return {
        "training_state": training_state,
        "step": metadata["step"],
        "config": metadata["config"],
    }
|