nnx-ppo 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nnx_ppo/__init__.py +1 -0
- nnx_ppo/algorithms/__init__.py +0 -0
- nnx_ppo/algorithms/callbacks.py +35 -0
- nnx_ppo/algorithms/checkpointing.py +204 -0
- nnx_ppo/algorithms/checkpointing_test.py +502 -0
- nnx_ppo/algorithms/config.py +127 -0
- nnx_ppo/algorithms/distillation.py +603 -0
- nnx_ppo/algorithms/distillation_test.py +201 -0
- nnx_ppo/algorithms/metrics.py +121 -0
- nnx_ppo/algorithms/ppo.py +572 -0
- nnx_ppo/algorithms/ppo_test.py +486 -0
- nnx_ppo/algorithms/rollout.py +279 -0
- nnx_ppo/algorithms/rollout_test.py +222 -0
- nnx_ppo/algorithms/types.py +150 -0
- nnx_ppo/conftest.py +10 -0
- nnx_ppo/jax_dataclass.py +41 -0
- nnx_ppo/networks/__init__.py +0 -0
- nnx_ppo/networks/adapter.py +133 -0
- nnx_ppo/networks/adapter_test.py +112 -0
- nnx_ppo/networks/ar1_rollout_test.py +288 -0
- nnx_ppo/networks/containers.py +218 -0
- nnx_ppo/networks/containers_test.py +127 -0
- nnx_ppo/networks/delay.py +95 -0
- nnx_ppo/networks/delay_test.py +174 -0
- nnx_ppo/networks/factories.py +146 -0
- nnx_ppo/networks/factories_test.py +123 -0
- nnx_ppo/networks/feedforward.py +51 -0
- nnx_ppo/networks/graph/__init__.py +12 -0
- nnx_ppo/networks/graph/connection.py +34 -0
- nnx_ppo/networks/graph/graph.py +448 -0
- nnx_ppo/networks/graph/graph_test.py +271 -0
- nnx_ppo/networks/graph/population.py +38 -0
- nnx_ppo/networks/normalizer.py +136 -0
- nnx_ppo/networks/normalizer_test.py +181 -0
- nnx_ppo/networks/recurrent.py +161 -0
- nnx_ppo/networks/recurrent_test.py +347 -0
- nnx_ppo/networks/sampling_layers.py +147 -0
- nnx_ppo/networks/types.py +113 -0
- nnx_ppo/networks/utils.py +326 -0
- nnx_ppo/networks/utils_test.py +316 -0
- nnx_ppo/networks/variational.py +216 -0
- nnx_ppo/networks/variational_test.py +178 -0
- nnx_ppo/py.typed +0 -0
- nnx_ppo/test_dummies/__init__.py +0 -0
- nnx_ppo/test_dummies/dict_obs_act_env.py +209 -0
- nnx_ppo/test_dummies/dummy_counter.py +75 -0
- nnx_ppo/test_dummies/mock_env.py +63 -0
- nnx_ppo/test_dummies/move_from_center_env.py +51 -0
- nnx_ppo/test_dummies/move_to_center_env.py +50 -0
- nnx_ppo/test_dummies/parrot_env.py +43 -0
- nnx_ppo/test_dummies/stateful_nets.py +40 -0
- nnx_ppo/wrappers/__init__.py +0 -0
- nnx_ppo/wrappers/episode_wrapper.py +40 -0
- nnx_ppo/wrappers/episode_wrapper_test.py +57 -0
- nnx_ppo/wrappers/reward_scaling_wrapper.py +28 -0
- nnx_ppo-0.2.1.dist-info/METADATA +136 -0
- nnx_ppo-0.2.1.dist-info/RECORD +60 -0
- nnx_ppo-0.2.1.dist-info/WHEEL +5 -0
- nnx_ppo-0.2.1.dist-info/licenses/LICENSE +28 -0
- nnx_ppo-0.2.1.dist-info/top_level.txt +1 -0
nnx_ppo/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# Package version string; matches the released wheel (nnx_ppo-0.2.1).
__version__ = "0.2.1"
|
|
File without changes
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
"""Callback helpers for training logging."""
|
|
2
|
+
|
|
3
|
+
from collections.abc import Callable
|
|
4
|
+
|
|
5
|
+
from nnx_ppo.algorithms.config import VideoData
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def wandb_video_fn(
    key: str = "eval_video", fps: int = 30
) -> Callable[[VideoData], None]:
    """Build a ``video_fn`` callback that uploads evaluation videos to wandb.

    ``wandb`` is imported when this factory is called rather than at module
    import time, so only users who actually log videos need it installed.

    Args:
        key: Name under which the video is logged in wandb.
        fps: Playback frame rate handed to ``wandb.Video``.

    Returns:
        A callable taking a :class:`VideoData` and logging its frames at
        ``data.step``; pass it as train_ppo's ``video_fn`` argument.

    Example:
        >>> result = train_ppo(
        ...     env, networks,
        ...     video_fn=wandb_video_fn(fps=50),
        ... )
    """
    import wandb

    def _log_video(data: VideoData) -> None:
        # Frames are THWC; wandb.Video wants TCHW, so move channels to axis 1.
        frames_tchw = data.frames.transpose(0, 3, 1, 2)
        clip = wandb.Video(frames_tchw, fps=fps, format="mp4")
        wandb.log({key: clip}, step=data.step)

    return _log_video
|
|
@@ -0,0 +1,204 @@
|
|
|
1
|
+
"""Checkpointing utilities for saving and loading training state."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import pickle
|
|
5
|
+
from collections.abc import Callable
|
|
6
|
+
from typing import Any, Optional, Protocol, runtime_checkable
|
|
7
|
+
|
|
8
|
+
import jax
|
|
9
|
+
from flax import nnx
|
|
10
|
+
|
|
11
|
+
from nnx_ppo.algorithms.config import TrainConfig
|
|
12
|
+
from nnx_ppo.algorithms.types import TrainingState
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@runtime_checkable
class CheckpointCallback(Protocol):
    """Structural type for checkpoint callbacks.

    Any callable whose parameters are named ``training_state`` and ``step``
    and which returns ``None`` matches this protocol, so implementations may
    be invoked with keyword arguments.
    """

    def __call__(self, training_state: TrainingState, step: int) -> None:
        """Receive the current training state and the training step count."""
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _split_net_state(networks):
    """Split network state: RngKey → pickle, everything else → orbax.

    orbax cannot handle JAX new-style PRNG key arrays (dtype ``key<fry>``), so
    we separate nnx.RngKey variables and persist them with pickle instead. All
    other variable types — including nnx.Param, nnx.RngCount, and custom
    Variable subclasses such as NormalizerStatistics — are saved via orbax.

    Args:
        networks: The nnx module (graph) whose state should be split.

    Returns:
        (non_key_state, rng_key_state, abstract_non_key) — the first two are
        nnx.State objects. The third mirrors ``non_key_state`` with every
        array replaced by a ``jax.ShapeDtypeStruct`` (shape, dtype and, when
        the leaf has one, its sharding); this is the target orbax needs for
        restoration.
    """
    _, rng_key_state, non_key_state = nnx.split(networks, nnx.RngKey, ...)

    def _abstract(x):
        # Keep the live array's sharding so restored arrays are placed like
        # the template instead of leaving device placement for orbax to infer.
        # Leaves without a ``sharding`` attribute (e.g. numpy arrays) get None,
        # which is the previous behaviour.
        return jax.ShapeDtypeStruct(
            x.shape, x.dtype, sharding=getattr(x, "sharding", None)
        )

    abstract_non_key = jax.tree_util.tree_map(_abstract, non_key_state)
    return non_key_state, rng_key_state, abstract_non_key
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def make_checkpoint_fn(
    directory: str,
    config: Optional[TrainConfig] = None,
) -> CheckpointCallback:
    """Create a checkpoint callback that saves TrainingState to disk.

    Each checkpoint is written to ``{directory}/step_{step:010d}/``, containing:

    - ``networks/`` — orbax checkpoint with all non-PRNG-key network variables
      (Param, RngCount, NormalizerStatistics, etc.)
    - ``optimizer/`` — orbax checkpoint with all optimizer state arrays
    - ``metadata.pkl`` — pickle file with network RngKey variables,
      all remaining TrainingState fields (``network_states``, ``env_states``,
      ``rng_key``, ``steps_taken``), the step count, and the optional
      TrainConfig.

    ``metadata.pkl`` is written last, first to a temporary file that is then
    renamed into place. Its presence therefore indicates that the orbax parts
    were saved, and an interrupted save never leaves a truncated
    ``metadata.pkl`` behind (at worst a stray ``metadata.pkl.tmp``).

    To resume training from a checkpoint, use :func:`load_checkpoint`.

    Args:
        directory: Base directory under which checkpoint subdirectories are
            created.
        config: Optional TrainConfig to store alongside each checkpoint, useful
            for reproducing training runs.

    Returns:
        A callback compatible with train_ppo's ``checkpoint_fn`` parameter.

    Example:
        >>> result = train_ppo(
        ...     env, networks, config,
        ...     checkpoint_fn=make_checkpoint_fn("/tmp/my_run", config=config),
        ... )
    """

    abs_directory = os.path.abspath(directory)

    def checkpoint_fn(training_state: TrainingState, step: int) -> None:
        import orbax.checkpoint as ocp

        step_dir = os.path.join(abs_directory, f"step_{step:010d}")
        os.makedirs(step_dir, exist_ok=True)

        # Split network state: everything except RngKey → orbax, RngKey → pickle.
        # orbax cannot handle JAX new-style PRNG key arrays.
        non_key_state, rng_key_state, _ = _split_net_state(training_state.networks)

        # The optimizer only contains float/int arrays; no key arrays.
        _, opt_state = nnx.split(training_state.optimizer)

        # Save parameter arrays with orbax. A fresh checkpointer is created per
        # call and immediately closed to ensure all async writes complete.
        checkpointer = ocp.StandardCheckpointer()
        try:
            checkpointer.save(os.path.join(step_dir, "networks"), non_key_state)
            checkpointer.save(os.path.join(step_dir, "optimizer"), opt_state)
        finally:
            checkpointer.close()

        # Save everything else with pickle (JAX arrays including PRNG keys are
        # pickle-safe).
        metadata = {
            "networks_rng_key_state": rng_key_state,
            "network_states": training_state.network_states,
            "env_states": training_state.env_states,
            "rng_key": training_state.rng_key,
            "steps_taken": training_state.steps_taken,
            "step": step,
            "config": config,
        }

        # Write-then-rename so a crash mid-dump cannot leave a truncated
        # metadata.pkl that would later fail to unpickle in load_checkpoint.
        metadata_path = os.path.join(step_dir, "metadata.pkl")
        tmp_path = metadata_path + ".tmp"
        with open(tmp_path, "wb") as f:
            pickle.dump(metadata, f)
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp_path, metadata_path)

    return checkpoint_fn
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def load_checkpoint(
    path: str,
    networks: Any,
    optimizer: nnx.Optimizer,
) -> dict[str, Any]:
    """Load a checkpoint saved by :func:`make_checkpoint_fn`.

    The ``networks`` and ``optimizer`` arguments serve as structural templates:
    their architecture must match the checkpoint, but their current parameter
    values are irrelevant and will be overwritten in-place by the checkpoint
    values.

    Warning:
        ``metadata.pkl`` is read with :mod:`pickle`, which can execute
        arbitrary code while loading. Only load checkpoints from trusted
        sources.

    Args:
        path: Path to the step checkpoint directory, e.g.
            ``/tmp/my_run/step_0000500000``.
        networks: Network instance with the same architecture as the checkpoint.
            Weights are updated in-place.
        optimizer: Optimizer instance with the same structure as the checkpoint.
            State is updated in-place.

    Returns:
        A dict with the following keys:

        - ``"training_state"`` — restored :class:`TrainingState`
        - ``"step"`` — training step at which the checkpoint was saved (int)
        - ``"config"`` — :class:`TrainConfig` if one was stored, else ``None``

    Raises:
        FileNotFoundError: If ``metadata.pkl`` is missing from ``path``. It is
            read before any orbax restore, so in that case ``networks`` and
            ``optimizer`` are left untouched.

    Example:
        >>> networks = factories.make_mlp_actor_critic(...)
        >>> training_state = ppo.new_training_state(env, networks, n_envs, seed)
        >>> ckpt = load_checkpoint(
        ...     "/tmp/my_run/step_0000500000",
        ...     training_state.networks,
        ...     training_state.optimizer,
        ... )
        >>> result = train_ppo(
        ...     env, networks, ckpt["config"],
        ...     initial_state=ckpt["training_state"],
        ... )
    """
    import orbax.checkpoint as ocp

    path = os.path.abspath(path)

    # Read the small pickle first: fail fast on an incomplete checkpoint
    # before doing the (potentially large) orbax restore.
    # SECURITY: pickle.load runs arbitrary code; only use trusted checkpoints.
    with open(os.path.join(path, "metadata.pkl"), "rb") as f:
        metadata = pickle.load(f)

    # Build abstract targets from the user-provided templates. Reuse the same
    # split as make_checkpoint_fn so the restore target always matches what
    # was saved: everything except nnx.RngKey comes from orbax.
    _, _, abstract_non_key = _split_net_state(networks)
    _, opt_template = nnx.split(optimizer)
    opt_abstract = jax.tree_util.tree_map(
        # Carry sharding (when present) so restored arrays match the template.
        lambda x: jax.ShapeDtypeStruct(
            x.shape, x.dtype, sharding=getattr(x, "sharding", None)
        ),
        opt_template,
    )

    checkpointer = ocp.StandardCheckpointer()
    try:
        restored_non_key = checkpointer.restore(
            os.path.join(path, "networks"), abstract_non_key
        )
        restored_opt = checkpointer.restore(
            os.path.join(path, "optimizer"), opt_abstract
        )
    finally:
        checkpointer.close()

    # Merge orbax-restored non-key state with pickled rng-key state,
    # then update the provided modules in-place.
    full_net_state = nnx.merge_state(
        restored_non_key, metadata["networks_rng_key_state"]
    )
    nnx.update(networks, full_net_state)
    nnx.update(optimizer, restored_opt)

    training_state = TrainingState(
        networks=networks,
        network_states=metadata["network_states"],
        env_states=metadata["env_states"],
        optimizer=optimizer,
        rng_key=metadata["rng_key"],
        steps_taken=metadata["steps_taken"],
    )
    return {
        "training_state": training_state,
        "step": metadata["step"],
        "config": metadata["config"],
    }
|