nnx-ppo 0.2.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. nnx_ppo-0.2.1/CHANGELOG.md +55 -0
  2. nnx_ppo-0.2.1/LICENSE +28 -0
  3. nnx_ppo-0.2.1/MANIFEST.in +12 -0
  4. nnx_ppo-0.2.1/PKG-INFO +136 -0
  5. nnx_ppo-0.2.1/README.md +91 -0
  6. nnx_ppo-0.2.1/nnx_ppo/__init__.py +1 -0
  7. nnx_ppo-0.2.1/nnx_ppo/algorithms/__init__.py +0 -0
  8. nnx_ppo-0.2.1/nnx_ppo/algorithms/callbacks.py +35 -0
  9. nnx_ppo-0.2.1/nnx_ppo/algorithms/checkpointing.py +204 -0
  10. nnx_ppo-0.2.1/nnx_ppo/algorithms/checkpointing_test.py +502 -0
  11. nnx_ppo-0.2.1/nnx_ppo/algorithms/config.py +127 -0
  12. nnx_ppo-0.2.1/nnx_ppo/algorithms/distillation.py +603 -0
  13. nnx_ppo-0.2.1/nnx_ppo/algorithms/distillation_test.py +201 -0
  14. nnx_ppo-0.2.1/nnx_ppo/algorithms/metrics.py +121 -0
  15. nnx_ppo-0.2.1/nnx_ppo/algorithms/ppo.py +572 -0
  16. nnx_ppo-0.2.1/nnx_ppo/algorithms/ppo_test.py +486 -0
  17. nnx_ppo-0.2.1/nnx_ppo/algorithms/rollout.py +279 -0
  18. nnx_ppo-0.2.1/nnx_ppo/algorithms/rollout_test.py +222 -0
  19. nnx_ppo-0.2.1/nnx_ppo/algorithms/types.py +150 -0
  20. nnx_ppo-0.2.1/nnx_ppo/conftest.py +10 -0
  21. nnx_ppo-0.2.1/nnx_ppo/jax_dataclass.py +41 -0
  22. nnx_ppo-0.2.1/nnx_ppo/networks/__init__.py +0 -0
  23. nnx_ppo-0.2.1/nnx_ppo/networks/adapter.py +133 -0
  24. nnx_ppo-0.2.1/nnx_ppo/networks/adapter_test.py +112 -0
  25. nnx_ppo-0.2.1/nnx_ppo/networks/ar1_rollout_test.py +288 -0
  26. nnx_ppo-0.2.1/nnx_ppo/networks/containers.py +218 -0
  27. nnx_ppo-0.2.1/nnx_ppo/networks/containers_test.py +127 -0
  28. nnx_ppo-0.2.1/nnx_ppo/networks/delay.py +95 -0
  29. nnx_ppo-0.2.1/nnx_ppo/networks/delay_test.py +174 -0
  30. nnx_ppo-0.2.1/nnx_ppo/networks/factories.py +146 -0
  31. nnx_ppo-0.2.1/nnx_ppo/networks/factories_test.py +123 -0
  32. nnx_ppo-0.2.1/nnx_ppo/networks/feedforward.py +51 -0
  33. nnx_ppo-0.2.1/nnx_ppo/networks/graph/__init__.py +12 -0
  34. nnx_ppo-0.2.1/nnx_ppo/networks/graph/connection.py +34 -0
  35. nnx_ppo-0.2.1/nnx_ppo/networks/graph/graph.py +448 -0
  36. nnx_ppo-0.2.1/nnx_ppo/networks/graph/graph_test.py +271 -0
  37. nnx_ppo-0.2.1/nnx_ppo/networks/graph/population.py +38 -0
  38. nnx_ppo-0.2.1/nnx_ppo/networks/normalizer.py +136 -0
  39. nnx_ppo-0.2.1/nnx_ppo/networks/normalizer_test.py +181 -0
  40. nnx_ppo-0.2.1/nnx_ppo/networks/recurrent.py +161 -0
  41. nnx_ppo-0.2.1/nnx_ppo/networks/recurrent_test.py +347 -0
  42. nnx_ppo-0.2.1/nnx_ppo/networks/sampling_layers.py +147 -0
  43. nnx_ppo-0.2.1/nnx_ppo/networks/types.py +113 -0
  44. nnx_ppo-0.2.1/nnx_ppo/networks/utils.py +326 -0
  45. nnx_ppo-0.2.1/nnx_ppo/networks/utils_test.py +316 -0
  46. nnx_ppo-0.2.1/nnx_ppo/networks/variational.py +216 -0
  47. nnx_ppo-0.2.1/nnx_ppo/networks/variational_test.py +178 -0
  48. nnx_ppo-0.2.1/nnx_ppo/py.typed +0 -0
  49. nnx_ppo-0.2.1/nnx_ppo/test_dummies/__init__.py +0 -0
  50. nnx_ppo-0.2.1/nnx_ppo/test_dummies/dict_obs_act_env.py +209 -0
  51. nnx_ppo-0.2.1/nnx_ppo/test_dummies/dummy_counter.py +75 -0
  52. nnx_ppo-0.2.1/nnx_ppo/test_dummies/mock_env.py +63 -0
  53. nnx_ppo-0.2.1/nnx_ppo/test_dummies/move_from_center_env.py +51 -0
  54. nnx_ppo-0.2.1/nnx_ppo/test_dummies/move_to_center_env.py +50 -0
  55. nnx_ppo-0.2.1/nnx_ppo/test_dummies/parrot_env.py +43 -0
  56. nnx_ppo-0.2.1/nnx_ppo/test_dummies/stateful_nets.py +40 -0
  57. nnx_ppo-0.2.1/nnx_ppo/wrappers/__init__.py +0 -0
  58. nnx_ppo-0.2.1/nnx_ppo/wrappers/episode_wrapper.py +40 -0
  59. nnx_ppo-0.2.1/nnx_ppo/wrappers/episode_wrapper_test.py +57 -0
  60. nnx_ppo-0.2.1/nnx_ppo/wrappers/reward_scaling_wrapper.py +28 -0
  61. nnx_ppo-0.2.1/nnx_ppo.egg-info/SOURCES.txt +60 -0
  62. nnx_ppo-0.2.1/pyproject.toml +62 -0
  63. nnx_ppo-0.2.1/setup.cfg +4 -0
@@ -0,0 +1,55 @@
1
+ # Changelog
2
+
3
+ All notable changes to `nnx-ppo` are recorded here. The format follows
4
+ [Keep a Changelog](https://keepachangelog.com/en/1.1.0/) and the project
5
+ adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
6
+
7
+ ## [0.2.1] — 2026-06-09
8
+
9
+ ### Added
10
+ - `LoggingLevel.THROUGHPUT` — emits `throughput/train_sps`,
11
+ `throughput/eval_sps`, and `throughput/video_sps` (env + render),
12
+ with `jax.block_until_ready` barriers so the numbers reflect
13
+ device-side wall-clock rather than JAX dispatch latency. Included in
14
+ `LoggingLevel.ALL`.
15
+ - `losses/clipping_fraction` under `LoggingLevel.ACTOR_EXTRA` — the
16
+ fraction of samples whose likelihood ratio left the PPO clip range
17
+ during the gradient phase. Tree-mapped, so it works with multi-actor
18
+ / multi-agent loglikelihoods.
19
+
20
+ ### Changed
21
+ - `RewardScalingWrapper` no longer depends on `mujoco_playground`; it
22
+ is now typed against the local `RLEnv` / `EnvState` protocols in
23
+ `nnx_ppo.algorithms.types`.
24
+ - The `[playground]` extra has been removed. `playground` and
25
+ `warp-lang` are now part of the `[dev]` and `[examples]` extras.
26
+ - Minimum `flax` is now `0.12.7`.
27
+ - License metadata switched to PEP 639 SPDX form
28
+ (`license = "BSD-3-Clause"` + `license-files = ["LICENSE"]`);
29
+ requires `setuptools>=77.0` at build time.
30
+
31
+ ### Fixed
32
+ - `PopulationGraph` no longer exposes its build-time registries as a
33
+ second set of `nnx.Param`s — newer Flax versions reflected through
34
+ the underscore-prefixed dicts, which tripped `nnx.jit`'s
35
+ consistent-aliasing check.
36
+
37
+ ### Removed
38
+ - `correlations/action_ll` (under `ACTOR_EXTRA`) — was only emitted
39
+ for 1-D action spaces and never fired in multi-actuator setups.
40
+
41
+ ## [0.2.0] — 2026-06-03
42
+
43
+ Initial PyPI release.
44
+
45
+ ### Added
46
+ - Stateful-network PPO training loop (`nnx_ppo.algorithms.ppo.train_ppo`).
47
+ - Network containers (`Sequential`, `Parallel`, `Concat`, `Splitter`) and the
48
+ two-port `PPOAdapter`.
49
+ - Built-in layers: `Dense`, `LSTM`, `AR1VariationalBottleneck`, `Normalizer`,
50
+ `Delay`, sampling layers, and graph-population utilities.
51
+ - Rollout machinery with per-environment state reset and `update_statistics`
52
+ hook for stats-bearing modules.
53
+ - Orbax-based checkpointing.
54
+ - Distillation utility (`nnx_ppo.algorithms.distillation`).
55
+ - Documentation site at <https://nnx-ppo.readthedocs.io>.
nnx_ppo-0.2.1/LICENSE ADDED
@@ -0,0 +1,28 @@
1
+ BSD 3-Clause License
2
+
3
+ Copyright (c) 2026, Emil Wärnberg
4
+
5
+ Redistribution and use in source and binary forms, with or without
6
+ modification, are permitted provided that the following conditions are met:
7
+
8
+ 1. Redistributions of source code must retain the above copyright notice, this
9
+ list of conditions and the following disclaimer.
10
+
11
+ 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ this list of conditions and the following disclaimer in the documentation
13
+ and/or other materials provided with the distribution.
14
+
15
+ 3. Neither the name of the copyright holder nor the names of its
16
+ contributors may be used to endorse or promote products derived from
17
+ this software without specific prior written permission.
18
+
19
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
@@ -0,0 +1,12 @@
1
+ include README.md
2
+ include LICENSE
3
+ include CHANGELOG.md
4
+ include pyproject.toml
5
+ include nnx_ppo/py.typed
6
+ recursive-exclude docs/_build *
7
+ recursive-exclude * __pycache__
8
+ recursive-exclude * *.py[cod]
9
+ prune .pytest_cache
10
+ prune wandb
11
+ prune .venv
12
+ prune nnx_ppo.egg-info
nnx_ppo-0.2.1/PKG-INFO ADDED
@@ -0,0 +1,136 @@
1
+ Metadata-Version: 2.4
2
+ Name: nnx-ppo
3
+ Version: 0.2.1
4
+ Summary: PPO for stateful/recurrent networks in JAX and flax.nnx.
5
+ Author-email: Emil Wärnberg <ewarnberg@fas.harvard.edu>
6
+ License-Expression: BSD-3-Clause
7
+ Project-URL: Homepage, https://github.com/emiwar/nnx-ppo
8
+ Project-URL: Documentation, https://nnx-ppo.readthedocs.io
9
+ Project-URL: Repository, https://github.com/emiwar/nnx-ppo
10
+ Project-URL: Issues, https://github.com/emiwar/nnx-ppo/issues
11
+ Keywords: reinforcement-learning,ppo,jax,flax,rl,neuroscience
12
+ Classifier: Development Status :: 3 - Alpha
13
+ Classifier: Intended Audience :: Science/Research
14
+ Classifier: Operating System :: OS Independent
15
+ Classifier: Programming Language :: Python :: 3
16
+ Classifier: Programming Language :: Python :: 3.11
17
+ Classifier: Programming Language :: Python :: 3.12
18
+ Classifier: Programming Language :: Python :: 3.13
19
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
20
+ Requires-Python: >=3.11
21
+ Description-Content-Type: text/markdown
22
+ License-File: LICENSE
23
+ Requires-Dist: jax
24
+ Requires-Dist: flax>=0.12.7
25
+ Requires-Dist: optax
26
+ Requires-Dist: orbax-checkpoint
27
+ Requires-Dist: jaxtyping
28
+ Requires-Dist: numpy
29
+ Provides-Extra: dev
30
+ Requires-Dist: pytest; extra == "dev"
31
+ Requires-Dist: pyright; extra == "dev"
32
+ Requires-Dist: beartype; extra == "dev"
33
+ Requires-Dist: absl-py; extra == "dev"
34
+ Requires-Dist: playground>=0.2.0; extra == "dev"
35
+ Requires-Dist: warp-lang<1.13; extra == "dev"
36
+ Provides-Extra: examples
37
+ Requires-Dist: brax; extra == "examples"
38
+ Requires-Dist: wandb; extra == "examples"
39
+ Requires-Dist: playground>=0.2.0; extra == "examples"
40
+ Requires-Dist: warp-lang<1.13; extra == "examples"
41
+ Provides-Extra: docs
42
+ Requires-Dist: sphinx>=7.0; extra == "docs"
43
+ Requires-Dist: sphinx-rtd-theme; extra == "docs"
44
+ Dynamic: license-file
45
+
46
+ # nnx-ppo
47
+
48
+ Experimental implementation of [Proximal Policy Optimization](https://en.wikipedia.org/wiki/Proximal_policy_optimization)
49
+ in [JAX](https://github.com/google/jax), with first-class support for
50
+ recurrent/stateful networks. Networks are built with
51
+ [flax.nnx](https://flax.readthedocs.io); environments follow the
52
+ [MuJoCo Playground](https://playground.mujoco.org/) API.
53
+
54
+ > **Status:** experimental — the API may change without notice.
55
+
56
+ ## Highlights
57
+
58
+ - **Stateful modules.** Recurrent layers, delayed connections, and noisy /
59
+ variational populations are all first-class citizens. Carry state is
60
+ threaded through rollout collection *and* multi-epoch loss replay, and is
61
+ reset correctly when the environment resets — something other JAX RL
62
+ libraries (e.g. Brax) do not natively support.
63
+ - **PyTree observations.** Observations can be arbitrary nested dicts, which
64
+ makes it easy to route different streams (proprioception, vision,
65
+ imitation targets, …) to different parts of a network.
66
+ - **PyTree actions and rewards.** Actions and rewards are also allowed to be
67
+ PyTrees, which simplifies multi-actuator and multi-agent setups.
68
+
69
+ ## Installation
70
+
71
+ ```bash
72
+ pip install nnx_ppo
73
+ ```
74
+
75
+ `nnx-ppo` installs the CPU build of JAX by default. For a CUDA 12 GPU build:
76
+
77
+ ```bash
78
+ pip install nnx_ppo "jax[cuda12]"
79
+ ```
80
+
81
+ Optional extras:
82
+
83
+ - `nnx_ppo[examples]` — `brax`, `wandb`, and
84
+ [`playground`](https://pypi.org/project/playground/) (import name
85
+ `mujoco_playground`) for the scripts in [examples/](examples/).
86
+ - `nnx_ppo[dev]` — test-suite dependencies (`pytest`, `pyright`, `beartype`,
87
+ `absl-py`, plus `playground` for the env-driven tests).
88
+
89
+ ## Quick example
90
+
91
+ ```python
92
+ from flax import nnx
93
+ import mujoco_playground
94
+
95
+ from nnx_ppo.algorithms import ppo
96
+ from nnx_ppo.algorithms.config import TrainConfig, PPOConfig, EvalConfig
97
+ from nnx_ppo.networks.factories import make_mlp_actor_critic
98
+ from nnx_ppo.wrappers import episode_wrapper
99
+
100
+ env = mujoco_playground.registry.load("CartpoleBalance")
101
+ train_env = episode_wrapper.EpisodeWrapper(env, 1000)
102
+
103
+ rngs = nnx.Rngs(0)
104
+ nets = make_mlp_actor_critic(
105
+ env.observation_size,
106
+ env.action_size,
107
+ actor_hidden_sizes=[64] * 4,
108
+ critic_hidden_sizes=[256] * 2,
109
+ rngs=rngs,
110
+ normalize_obs=True,
111
+ )
112
+
113
+ result = ppo.train_ppo(
114
+ train_env,
115
+ nets,
116
+ TrainConfig(
117
+ ppo=PPOConfig(n_envs=1024, rollout_length=30, total_steps=10_000_000),
118
+ eval=EvalConfig(enabled=True, every_steps=500_000, n_envs=64,
119
+ max_episode_length=1000),
120
+ ),
121
+ eval_env=env,
122
+ )
123
+ print(f"Final eval reward: {result.eval_history[-1]['episode_reward_mean']}")
124
+ ```
125
+
126
+ See [examples/wandb_logging.py](examples/wandb_logging.py) for a complete
127
+ training script with W&B logging and video recording.
128
+
129
+ ## Documentation
130
+
131
+ Full documentation, tutorials, and API reference are at
132
+ <https://nnx-ppo.readthedocs.io>.
133
+
134
+ ## License
135
+
136
+ BSD 3-Clause — see [LICENSE](LICENSE).
@@ -0,0 +1,91 @@
1
+ # nnx-ppo
2
+
3
+ Experimental implementation of [Proximal Policy Optimization](https://en.wikipedia.org/wiki/Proximal_policy_optimization)
4
+ in [JAX](https://github.com/google/jax), with first-class support for
5
+ recurrent/stateful networks. Networks are built with
6
+ [flax.nnx](https://flax.readthedocs.io); environments follow the
7
+ [MuJoCo Playground](https://playground.mujoco.org/) API.
8
+
9
+ > **Status:** experimental — the API may change without notice.
10
+
11
+ ## Highlights
12
+
13
+ - **Stateful modules.** Recurrent layers, delayed connections, and noisy /
14
+ variational populations are all first-class citizens. Carry state is
15
+ threaded through rollout collection *and* multi-epoch loss replay, and is
16
+ reset correctly when the environment resets — something other JAX RL
17
+ libraries (e.g. Brax) do not natively support.
18
+ - **PyTree observations.** Observations can be arbitrary nested dicts, which
19
+ makes it easy to route different streams (proprioception, vision,
20
+ imitation targets, …) to different parts of a network.
21
+ - **PyTree actions and rewards.** Actions and rewards are also allowed to be
22
+ PyTrees, which simplifies multi-actuator and multi-agent setups.
23
+
24
+ ## Installation
25
+
26
+ ```bash
27
+ pip install nnx_ppo
28
+ ```
29
+
30
+ `nnx-ppo` installs the CPU build of JAX by default. For a CUDA 12 GPU build:
31
+
32
+ ```bash
33
+ pip install nnx_ppo "jax[cuda12]"
34
+ ```
35
+
36
+ Optional extras:
37
+
38
+ - `nnx_ppo[examples]` — `brax`, `wandb`, and
39
+ [`playground`](https://pypi.org/project/playground/) (import name
40
+ `mujoco_playground`) for the scripts in [examples/](examples/).
41
+ - `nnx_ppo[dev]` — test-suite dependencies (`pytest`, `pyright`, `beartype`,
42
+ `absl-py`, plus `playground` for the env-driven tests).
43
+
44
+ ## Quick example
45
+
46
+ ```python
47
+ from flax import nnx
48
+ import mujoco_playground
49
+
50
+ from nnx_ppo.algorithms import ppo
51
+ from nnx_ppo.algorithms.config import TrainConfig, PPOConfig, EvalConfig
52
+ from nnx_ppo.networks.factories import make_mlp_actor_critic
53
+ from nnx_ppo.wrappers import episode_wrapper
54
+
55
+ env = mujoco_playground.registry.load("CartpoleBalance")
56
+ train_env = episode_wrapper.EpisodeWrapper(env, 1000)
57
+
58
+ rngs = nnx.Rngs(0)
59
+ nets = make_mlp_actor_critic(
60
+ env.observation_size,
61
+ env.action_size,
62
+ actor_hidden_sizes=[64] * 4,
63
+ critic_hidden_sizes=[256] * 2,
64
+ rngs=rngs,
65
+ normalize_obs=True,
66
+ )
67
+
68
+ result = ppo.train_ppo(
69
+ train_env,
70
+ nets,
71
+ TrainConfig(
72
+ ppo=PPOConfig(n_envs=1024, rollout_length=30, total_steps=10_000_000),
73
+ eval=EvalConfig(enabled=True, every_steps=500_000, n_envs=64,
74
+ max_episode_length=1000),
75
+ ),
76
+ eval_env=env,
77
+ )
78
+ print(f"Final eval reward: {result.eval_history[-1]['episode_reward_mean']}")
79
+ ```
80
+
81
+ See [examples/wandb_logging.py](examples/wandb_logging.py) for a complete
82
+ training script with W&B logging and video recording.
83
+
84
+ ## Documentation
85
+
86
+ Full documentation, tutorials, and API reference are at
87
+ <https://nnx-ppo.readthedocs.io>.
88
+
89
+ ## License
90
+
91
+ BSD 3-Clause — see [LICENSE](LICENSE).
@@ -0,0 +1 @@
1
+ __version__ = "0.2.1"
File without changes
@@ -0,0 +1,35 @@
1
+ """Callback helpers for training logging."""
2
+
3
+ from collections.abc import Callable
4
+
5
+ from nnx_ppo.algorithms.config import VideoData
6
+
7
+
8
def wandb_video_fn(
    key: str = "eval_video", fps: int = 30
) -> Callable[[VideoData], None]:
    """Build a video callback that uploads clips to Weights & Biases.

    ``wandb`` is imported lazily, when this factory is called, rather than
    at module import time.

    Args:
        key: The wandb log key under which the video is stored.
        fps: Frames per second passed to ``wandb.Video``.

    Returns:
        A callback function compatible with train_ppo's video_fn parameter.

    Example:
        >>> result = train_ppo(
        ...     env, networks,
        ...     video_fn=wandb_video_fn(fps=50),
        ... )
    """
    import wandb

    def video_fn(data: VideoData) -> None:
        # Frames arrive as THWC; wandb expects TCHW, so move the channel
        # axis to position 1.
        frames_tchw = data.frames.transpose(0, 3, 1, 2)
        clip = wandb.Video(frames_tchw, fps=fps, format="mp4")
        wandb.log({key: clip}, step=data.step)

    return video_fn
@@ -0,0 +1,204 @@
1
+ """Checkpointing utilities for saving and loading training state."""
2
+
3
+ import os
4
+ import pickle
5
+ from collections.abc import Callable
6
+ from typing import Any, Optional, Protocol, runtime_checkable
7
+
8
+ import jax
9
+ from flax import nnx
10
+
11
+ from nnx_ppo.algorithms.config import TrainConfig
12
+ from nnx_ppo.algorithms.types import TrainingState
13
+
14
+
15
@runtime_checkable
class CheckpointCallback(Protocol):
    """Structural type for checkpoint callbacks.

    A conforming callable takes the current ``training_state`` and the
    ``step`` it belongs to, using exactly these parameter names, and
    returns nothing.
    """

    def __call__(self, training_state: TrainingState, step: int) -> None:
        ...
20
+
21
+
22
def _split_net_state(networks):
    """Partition a network's state into PRNG-key and non-key parts.

    orbax cannot handle JAX new-style PRNG key arrays (dtype ``key<fry>``),
    so nnx.RngKey variables are separated out to be persisted with pickle.
    Every other variable type (nnx.Param, nnx.RngCount, and custom Variable
    subclasses such as NormalizerStatistics) stays in the part that is
    saved via orbax.

    Args:
        networks: The nnx module whose state should be split.

    Returns:
        A tuple ``(non_key_state, rng_key_state, abstract_non_key)``. The
        first two are nnx.State objects; the third mirrors ``non_key_state``
        with every leaf replaced by a ``jax.ShapeDtypeStruct`` and serves as
        the abstract target for orbax restoration.
    """

    def _as_abstract(leaf):
        # Keep only shape and dtype; orbax needs no concrete values here.
        return jax.ShapeDtypeStruct(leaf.shape, leaf.dtype)

    # The graph definition (first element) is not needed by the callers.
    _graphdef, rng_key_state, non_key_state = nnx.split(
        networks, nnx.RngKey, ...
    )
    abstract_non_key = jax.tree_util.tree_map(_as_abstract, non_key_state)
    return non_key_state, rng_key_state, abstract_non_key
40
+
41
+
42
def make_checkpoint_fn(
    directory: str,
    config: Optional[TrainConfig] = None,
) -> CheckpointCallback:
    """Build a callback that writes a TrainingState snapshot to disk.

    Every call of the returned callback writes to
    ``{directory}/step_{step:010d}/``, which then contains:

    - ``networks/`` — orbax checkpoint with all non-PRNG-key network
      variables (Param, RngCount, NormalizerStatistics, etc.)
    - ``optimizer/`` — orbax checkpoint with all optimizer state arrays
    - ``metadata.pkl`` — pickle file with the network RngKey variables, the
      remaining TrainingState fields (``network_states``, ``env_states``,
      ``rng_key``, ``steps_taken``), the step count, and the optional
      TrainConfig.

    Use :func:`load_checkpoint` to resume training from such a directory.

    Args:
        directory: Base directory under which the per-step subdirectories
            are created. It is resolved to an absolute path once, when this
            factory is called.
        config: Optional TrainConfig stored alongside each checkpoint,
            useful for reproducing training runs.

    Returns:
        A callback compatible with train_ppo's ``checkpoint_fn`` parameter.

    Example:
        >>> result = train_ppo(
        ...     env, networks, config,
        ...     checkpoint_fn=make_checkpoint_fn("/tmp/my_run", config=config),
        ... )
    """
    base_dir = os.path.abspath(directory)

    def checkpoint_fn(training_state: TrainingState, step: int) -> None:
        import orbax.checkpoint as ocp

        target_dir = os.path.join(base_dir, f"step_{step:010d}")
        os.makedirs(target_dir, exist_ok=True)

        # RngKey variables go to pickle because orbax cannot handle JAX
        # new-style PRNG key arrays; everything else goes to orbax.
        non_key_state, rng_key_state, _ = _split_net_state(
            training_state.networks
        )

        # The optimizer only contains float/int arrays; no key arrays.
        _, opt_state = nnx.split(training_state.optimizer)

        orbax_payloads = (
            ("networks", non_key_state),
            ("optimizer", opt_state),
        )

        # A fresh checkpointer is created per call and closed right away so
        # that all async writes complete before the callback returns.
        checkpointer = ocp.StandardCheckpointer()
        try:
            for subdir, payload in orbax_payloads:
                checkpointer.save(os.path.join(target_dir, subdir), payload)
        finally:
            checkpointer.close()

        # Everything else is pickled (JAX arrays, including PRNG keys, are
        # pickle-safe).
        metadata = {
            "networks_rng_key_state": rng_key_state,
            "network_states": training_state.network_states,
            "env_states": training_state.env_states,
            "rng_key": training_state.rng_key,
            "steps_taken": training_state.steps_taken,
            "step": step,
            "config": config,
        }
        metadata_path = os.path.join(target_dir, "metadata.pkl")
        with open(metadata_path, "wb") as f:
            pickle.dump(metadata, f)

    return checkpoint_fn
115
+
116
+
117
def load_checkpoint(
    path: str,
    networks: Any,
    optimizer: nnx.Optimizer,
) -> dict[str, Any]:
    """Load a checkpoint saved by :func:`make_checkpoint_fn`.

    The ``networks`` and ``optimizer`` arguments serve as structural templates:
    their architecture must match the checkpoint, but their current parameter
    values are irrelevant and will be overwritten in-place by the checkpoint
    values.

    Warning:
        ``metadata.pkl`` is read with :mod:`pickle`, which can execute
        arbitrary code while loading. Only load checkpoints that come from
        a source you trust.

    Args:
        path: Path to the step checkpoint directory, e.g.
            ``/tmp/my_run/step_0000500000``.
        networks: Network instance with the same architecture as the checkpoint.
            Weights are updated in-place.
        optimizer: Optimizer instance with the same structure as the checkpoint.
            State is updated in-place.

    Returns:
        A dict with the following keys:

        - ``"training_state"`` — restored :class:`TrainingState`
        - ``"step"`` — training step at which the checkpoint was saved (int)
        - ``"config"`` — :class:`TrainConfig` if one was stored, else ``None``

    Example:
        >>> networks = factories.make_mlp_actor_critic(...)
        >>> training_state = ppo.new_training_state(env, networks, n_envs, seed)
        >>> ckpt = load_checkpoint(
        ...     "/tmp/my_run/step_0000500000",
        ...     training_state.networks,
        ...     training_state.optimizer,
        ... )
        >>> result = train_ppo(
        ...     env, networks, ckpt["config"],
        ...     initial_state=ckpt["training_state"],
        ... )
    """
    import orbax.checkpoint as ocp

    path = os.path.abspath(path)

    # Build abstract restore targets from the user-provided templates. The
    # network side reuses the same helper as the save path so both sides
    # always agree on which variables go to orbax and which go to pickle.
    _, _, abstract_non_key = _split_net_state(networks)
    _, opt_template = nnx.split(optimizer)
    opt_abstract = jax.tree_util.tree_map(
        lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype), opt_template
    )

    checkpointer = ocp.StandardCheckpointer()
    try:
        restored_non_key = checkpointer.restore(
            os.path.join(path, "networks"), abstract_non_key
        )
        restored_opt = checkpointer.restore(
            os.path.join(path, "optimizer"), opt_abstract
        )
    finally:
        checkpointer.close()

    # SECURITY: pickle.load runs arbitrary code embedded in the file; the
    # checkpoint directory must come from a trusted source.
    with open(os.path.join(path, "metadata.pkl"), "rb") as f:
        metadata = pickle.load(f)

    # Merge orbax-restored non-key state with pickled rng-key state,
    # then update the provided modules in-place.
    full_net_state = nnx.merge_state(
        restored_non_key, metadata["networks_rng_key_state"]
    )
    nnx.update(networks, full_net_state)
    nnx.update(optimizer, restored_opt)

    training_state = TrainingState(
        networks=networks,
        network_states=metadata["network_states"],
        env_states=metadata["env_states"],
        optimizer=optimizer,
        rng_key=metadata["rng_key"],
        steps_taken=metadata["steps_taken"],
    )
    return {
        "training_state": training_state,
        "step": metadata["step"],
        "config": metadata["config"],
    }