nnx-ppo 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60) hide show
  1. nnx_ppo/__init__.py +1 -0
  2. nnx_ppo/algorithms/__init__.py +0 -0
  3. nnx_ppo/algorithms/callbacks.py +35 -0
  4. nnx_ppo/algorithms/checkpointing.py +204 -0
  5. nnx_ppo/algorithms/checkpointing_test.py +502 -0
  6. nnx_ppo/algorithms/config.py +127 -0
  7. nnx_ppo/algorithms/distillation.py +603 -0
  8. nnx_ppo/algorithms/distillation_test.py +201 -0
  9. nnx_ppo/algorithms/metrics.py +121 -0
  10. nnx_ppo/algorithms/ppo.py +572 -0
  11. nnx_ppo/algorithms/ppo_test.py +486 -0
  12. nnx_ppo/algorithms/rollout.py +279 -0
  13. nnx_ppo/algorithms/rollout_test.py +222 -0
  14. nnx_ppo/algorithms/types.py +150 -0
  15. nnx_ppo/conftest.py +10 -0
  16. nnx_ppo/jax_dataclass.py +41 -0
  17. nnx_ppo/networks/__init__.py +0 -0
  18. nnx_ppo/networks/adapter.py +133 -0
  19. nnx_ppo/networks/adapter_test.py +112 -0
  20. nnx_ppo/networks/ar1_rollout_test.py +288 -0
  21. nnx_ppo/networks/containers.py +218 -0
  22. nnx_ppo/networks/containers_test.py +127 -0
  23. nnx_ppo/networks/delay.py +95 -0
  24. nnx_ppo/networks/delay_test.py +174 -0
  25. nnx_ppo/networks/factories.py +146 -0
  26. nnx_ppo/networks/factories_test.py +123 -0
  27. nnx_ppo/networks/feedforward.py +51 -0
  28. nnx_ppo/networks/graph/__init__.py +12 -0
  29. nnx_ppo/networks/graph/connection.py +34 -0
  30. nnx_ppo/networks/graph/graph.py +448 -0
  31. nnx_ppo/networks/graph/graph_test.py +271 -0
  32. nnx_ppo/networks/graph/population.py +38 -0
  33. nnx_ppo/networks/normalizer.py +136 -0
  34. nnx_ppo/networks/normalizer_test.py +181 -0
  35. nnx_ppo/networks/recurrent.py +161 -0
  36. nnx_ppo/networks/recurrent_test.py +347 -0
  37. nnx_ppo/networks/sampling_layers.py +147 -0
  38. nnx_ppo/networks/types.py +113 -0
  39. nnx_ppo/networks/utils.py +326 -0
  40. nnx_ppo/networks/utils_test.py +316 -0
  41. nnx_ppo/networks/variational.py +216 -0
  42. nnx_ppo/networks/variational_test.py +178 -0
  43. nnx_ppo/py.typed +0 -0
  44. nnx_ppo/test_dummies/__init__.py +0 -0
  45. nnx_ppo/test_dummies/dict_obs_act_env.py +209 -0
  46. nnx_ppo/test_dummies/dummy_counter.py +75 -0
  47. nnx_ppo/test_dummies/mock_env.py +63 -0
  48. nnx_ppo/test_dummies/move_from_center_env.py +51 -0
  49. nnx_ppo/test_dummies/move_to_center_env.py +50 -0
  50. nnx_ppo/test_dummies/parrot_env.py +43 -0
  51. nnx_ppo/test_dummies/stateful_nets.py +40 -0
  52. nnx_ppo/wrappers/__init__.py +0 -0
  53. nnx_ppo/wrappers/episode_wrapper.py +40 -0
  54. nnx_ppo/wrappers/episode_wrapper_test.py +57 -0
  55. nnx_ppo/wrappers/reward_scaling_wrapper.py +28 -0
  56. nnx_ppo-0.2.1.dist-info/METADATA +136 -0
  57. nnx_ppo-0.2.1.dist-info/RECORD +60 -0
  58. nnx_ppo-0.2.1.dist-info/WHEEL +5 -0
  59. nnx_ppo-0.2.1.dist-info/licenses/LICENSE +28 -0
  60. nnx_ppo-0.2.1.dist-info/top_level.txt +1 -0
nnx_ppo/__init__.py ADDED
@@ -0,0 +1 @@
1
+ __version__ = "0.2.1"
File without changes
@@ -0,0 +1,35 @@
1
+ """Callback helpers for training logging."""
2
+
3
+ from collections.abc import Callable
4
+
5
+ from nnx_ppo.algorithms.config import VideoData
6
+
7
+
8
def wandb_video_fn(
    key: str = "eval_video", fps: int = 30
) -> Callable[[VideoData], None]:
    """Build a callback that uploads rendered videos to wandb.

    ``wandb`` is imported when this factory is called rather than at module
    import time, so the dependency is only needed if the callback is used.

    Args:
        key: Name under which the video is logged in wandb.
        fps: Playback rate, in frames per second, of the logged video.

    Returns:
        A function taking a ``VideoData`` and returning ``None``, suitable
        for train_ppo's ``video_fn`` parameter.

    Example:
        >>> result = train_ppo(
        ...     env, networks,
        ...     video_fn=wandb_video_fn(fps=50),
        ... )
    """
    import wandb

    def video_fn(data: VideoData) -> None:
        # Frames arrive as (T, H, W, C); wandb expects (T, C, H, W).
        tchw_frames = data.frames.transpose(0, 3, 1, 2)
        clip = wandb.Video(tchw_frames, fps=fps, format="mp4")
        payload = {key: clip}
        wandb.log(payload, step=data.step)

    return video_fn
@@ -0,0 +1,204 @@
1
+ """Checkpointing utilities for saving and loading training state."""
2
+
3
+ import os
4
+ import pickle
5
+ from collections.abc import Callable
6
+ from typing import Any, Optional, Protocol, runtime_checkable
7
+
8
+ import jax
9
+ from flax import nnx
10
+
11
+ from nnx_ppo.algorithms.config import TrainConfig
12
+ from nnx_ppo.algorithms.types import TrainingState
13
+
14
+
15
@runtime_checkable
class CheckpointCallback(Protocol):
    """Structural type for checkpoint callbacks.

    Implementations are called with the parameters ``training_state`` and
    ``step`` and return nothing.
    """

    def __call__(self, training_state: TrainingState, step: int) -> None:
        """Handle a checkpoint request for ``training_state`` at ``step``."""
20
+
21
+
22
def _split_net_state(networks):
    """Partition network state into an orbax part and a pickle part.

    orbax cannot serialise JAX new-style PRNG key arrays (dtype
    ``key<fry>``). The nnx.RngKey variables are therefore separated out so
    they can be persisted with pickle, while every other variable type
    (nnx.Param, nnx.RngCount, custom Variable subclasses such as
    NormalizerStatistics, ...) goes through orbax.

    Args:
        networks: The nnx module whose state should be split.

    Returns:
        A ``(non_key_state, rng_key_state, abstract_non_key)`` tuple. The
        first two entries are nnx.State objects; the third mirrors
        ``non_key_state`` with each leaf replaced by a
        ``jax.ShapeDtypeStruct`` and is the target orbax needs to restore.
    """

    def _as_abstract(leaf):
        # Keep only shape and dtype; the values are irrelevant for a target.
        return jax.ShapeDtypeStruct(leaf.shape, leaf.dtype)

    # The first element (the graph definition) is not needed here.
    parts = nnx.split(networks, nnx.RngKey, ...)
    keys_only = parts[1]
    everything_else = parts[2]

    abstract_target = jax.tree_util.tree_map(_as_abstract, everything_else)
    return everything_else, keys_only, abstract_target
40
+
41
+
42
def make_checkpoint_fn(
    directory: str,
    config: Optional[TrainConfig] = None,
) -> CheckpointCallback:
    """Create a checkpoint callback that saves TrainingState to disk.

    Each checkpoint is written to ``{directory}/step_{step:010d}/``, containing:

    - ``networks/`` — orbax checkpoint with all non-PRNG-key network variables
      (Param, RngCount, NormalizerStatistics, etc.)
    - ``optimizer/`` — orbax checkpoint with all optimizer state arrays
    - ``metadata.pkl`` — pickle file with network RngKey variables,
      all remaining TrainingState fields (``network_states``, ``env_states``,
      ``rng_key``, ``steps_taken``), the step count, and the optional
      TrainConfig.

    ``metadata.pkl`` is written last and is moved into place atomically, so
    its presence indicates that the pickle file itself is complete.

    To resume training from a checkpoint, use :func:`load_checkpoint`.

    Args:
        directory: Base directory under which checkpoint subdirectories are
            created. Resolved to an absolute path when this factory is
            called.
        config: Optional TrainConfig to store alongside each checkpoint, useful
            for reproducing training runs.

    Returns:
        A callback compatible with train_ppo's ``checkpoint_fn`` parameter.

    Example:
        >>> result = train_ppo(
        ...     env, networks, config,
        ...     checkpoint_fn=make_checkpoint_fn("/tmp/my_run", config=config),
        ... )
    """

    abs_directory = os.path.abspath(directory)

    def checkpoint_fn(training_state: TrainingState, step: int) -> None:
        import orbax.checkpoint as ocp

        step_dir = os.path.join(abs_directory, f"step_{step:010d}")
        os.makedirs(step_dir, exist_ok=True)

        # Split network state: everything except RngKey → orbax, RngKey → pickle.
        # orbax cannot handle JAX new-style PRNG key arrays.
        non_key_state, rng_key_state, _ = _split_net_state(training_state.networks)

        # The optimizer only contains float/int arrays; no key arrays.
        _, opt_state = nnx.split(training_state.optimizer)

        # Save parameter arrays with orbax. A fresh checkpointer is created per
        # call and immediately closed to ensure all async writes complete.
        checkpointer = ocp.StandardCheckpointer()
        try:
            checkpointer.save(os.path.join(step_dir, "networks"), non_key_state)
            checkpointer.save(os.path.join(step_dir, "optimizer"), opt_state)
        finally:
            checkpointer.close()

        # Save everything else with pickle (JAX arrays including PRNG keys are
        # pickle-safe).
        metadata = {
            "networks_rng_key_state": rng_key_state,
            "network_states": training_state.network_states,
            "env_states": training_state.env_states,
            "rng_key": training_state.rng_key,
            "steps_taken": training_state.steps_taken,
            "step": step,
            "config": config,
        }

        # Write to a sibling temp file and rename it into place so that an
        # interruption during pickling never leaves a truncated metadata.pkl
        # behind. os.replace is atomic when source and destination are on the
        # same filesystem, which holds here because both live in step_dir.
        metadata_path = os.path.join(step_dir, "metadata.pkl")
        tmp_path = metadata_path + ".tmp"
        with open(tmp_path, "wb") as f:
            pickle.dump(metadata, f)
        os.replace(tmp_path, metadata_path)

    return checkpoint_fn
115
+
116
+
117
def load_checkpoint(
    path: str,
    networks: Any,
    optimizer: nnx.Optimizer,
) -> dict[str, Any]:
    """Load a checkpoint saved by :func:`make_checkpoint_fn`.

    The ``networks`` and ``optimizer`` arguments serve as structural templates:
    their architecture must match the checkpoint, but their current parameter
    values are irrelevant and will be overwritten in-place by the checkpoint
    values.

    Warning:
        ``metadata.pkl`` is deserialised with :mod:`pickle`, which can execute
        arbitrary code while loading. Only load checkpoints that come from a
        source you trust.

    Args:
        path: Path to the step checkpoint directory, e.g.
            ``/tmp/my_run/step_0000500000``.
        networks: Network instance with the same architecture as the checkpoint.
            Weights are updated in-place.
        optimizer: Optimizer instance with the same structure as the checkpoint.
            State is updated in-place.

    Returns:
        A dict with the following keys:

        - ``"training_state"`` — restored :class:`TrainingState`
        - ``"step"`` — training step at which the checkpoint was saved (int)
        - ``"config"`` — :class:`TrainConfig` if one was stored, else ``None``

    Raises:
        FileNotFoundError: If ``metadata.pkl`` is missing from ``path``.

    Example:
        >>> networks = factories.make_mlp_actor_critic(...)
        >>> training_state = ppo.new_training_state(env, networks, n_envs, seed)
        >>> ckpt = load_checkpoint(
        ...     "/tmp/my_run/step_0000500000",
        ...     training_state.networks,
        ...     training_state.optimizer,
        ... )
        >>> result = train_ppo(
        ...     env, networks, ckpt["config"],
        ...     initial_state=ckpt["training_state"],
        ... )
    """
    import orbax.checkpoint as ocp

    path = os.path.abspath(path)

    # metadata.pkl is the last file written by make_checkpoint_fn, so check for
    # it up front: this fails with a clear message before the orbax restore
    # instead of after it.
    metadata_path = os.path.join(path, "metadata.pkl")
    if not os.path.isfile(metadata_path):
        raise FileNotFoundError(
            f"No metadata.pkl found in checkpoint directory: {path}"
        )

    # Build abstract targets from the user-provided templates. The same helper
    # used when saving guarantees the RngKey / non-RngKey partition matches;
    # the RngKey part is restored from pickle below.
    _, _, abstract_non_key = _split_net_state(networks)
    _, opt_template = nnx.split(optimizer)
    opt_abstract = jax.tree_util.tree_map(
        lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype), opt_template
    )

    checkpointer = ocp.StandardCheckpointer()
    try:
        restored_non_key = checkpointer.restore(
            os.path.join(path, "networks"), abstract_non_key
        )
        restored_opt = checkpointer.restore(
            os.path.join(path, "optimizer"), opt_abstract
        )
    finally:
        checkpointer.close()

    # SECURITY: pickle.load can run arbitrary code; see the docstring warning.
    with open(metadata_path, "rb") as f:
        metadata = pickle.load(f)

    # Merge orbax-restored non-key state with pickled rng-key state,
    # then update the provided modules in-place.
    full_net_state = nnx.merge_state(
        restored_non_key, metadata["networks_rng_key_state"]
    )
    nnx.update(networks, full_net_state)
    nnx.update(optimizer, restored_opt)

    training_state = TrainingState(
        networks=networks,
        network_states=metadata["network_states"],
        env_states=metadata["env_states"],
        optimizer=optimizer,
        rng_key=metadata["rng_key"],
        steps_taken=metadata["steps_taken"],
    )
    return {
        "training_state": training_state,
        "step": metadata["step"],
        "config": metadata["config"],
    }