cusrl 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83) hide show
  1. cusrl/__init__.py +107 -0
  2. cusrl/environment/__init__.py +11 -0
  3. cusrl/environment/gym.py +157 -0
  4. cusrl/environment/isaaclab.py +133 -0
  5. cusrl/hook/__init__.py +58 -0
  6. cusrl/hook/advantage.py +100 -0
  7. cusrl/hook/condition.py +57 -0
  8. cusrl/hook/gae.py +143 -0
  9. cusrl/hook/gradient.py +48 -0
  10. cusrl/hook/initialization.py +94 -0
  11. cusrl/hook/lr_schedule.py +178 -0
  12. cusrl/hook/normalization.py +194 -0
  13. cusrl/hook/on_policy.py +35 -0
  14. cusrl/hook/ppo.py +77 -0
  15. cusrl/hook/representation.py +132 -0
  16. cusrl/hook/rnd.py +66 -0
  17. cusrl/hook/schedule.py +114 -0
  18. cusrl/hook/smoothness.py +75 -0
  19. cusrl/hook/statistics.py +28 -0
  20. cusrl/hook/symmetry.py +233 -0
  21. cusrl/hook/value.py +158 -0
  22. cusrl/launch/export.py +43 -0
  23. cusrl/launch/play.py +45 -0
  24. cusrl/launch/train.py +62 -0
  25. cusrl/logger/__init__.py +5 -0
  26. cusrl/logger/make_factory.py +18 -0
  27. cusrl/logger/tensorboard_logger.py +28 -0
  28. cusrl/logger/wandb_logger.py +68 -0
  29. cusrl/module/__init__.py +39 -0
  30. cusrl/module/actor.py +203 -0
  31. cusrl/module/attention.py +614 -0
  32. cusrl/module/bijector.py +115 -0
  33. cusrl/module/cnn.py +75 -0
  34. cusrl/module/critic.py +73 -0
  35. cusrl/module/distribution.py +263 -0
  36. cusrl/module/inference.py +57 -0
  37. cusrl/module/mlp.py +63 -0
  38. cusrl/module/module.py +182 -0
  39. cusrl/module/normalization.py +59 -0
  40. cusrl/module/rnn.py +167 -0
  41. cusrl/module/sequential.py +70 -0
  42. cusrl/module/simba.py +70 -0
  43. cusrl/preset/__init__.py +5 -0
  44. cusrl/preset/ppo.py +216 -0
  45. cusrl/sampler/__init__.py +11 -0
  46. cusrl/sampler/mini_batch_sampler.py +78 -0
  47. cusrl/template/__init__.py +27 -0
  48. cusrl/template/actor_critic.py +321 -0
  49. cusrl/template/agent.py +259 -0
  50. cusrl/template/buffer.py +271 -0
  51. cusrl/template/environment.py +208 -0
  52. cusrl/template/hook.py +244 -0
  53. cusrl/template/logger.py +76 -0
  54. cusrl/template/optimizer.py +68 -0
  55. cusrl/template/player.py +114 -0
  56. cusrl/template/trainer.py +290 -0
  57. cusrl/template/trial.py +103 -0
  58. cusrl/utils/__init__.py +30 -0
  59. cusrl/utils/cli.py +59 -0
  60. cusrl/utils/config.py +75 -0
  61. cusrl/utils/distributed.py +146 -0
  62. cusrl/utils/export.py +98 -0
  63. cusrl/utils/helper.py +122 -0
  64. cusrl/utils/metrics.py +72 -0
  65. cusrl/utils/nest.py +82 -0
  66. cusrl/utils/normalizer.py +276 -0
  67. cusrl/utils/recurrent.py +163 -0
  68. cusrl/utils/timing.py +63 -0
  69. cusrl/utils/typing.py +45 -0
  70. cusrl/utils/video.py +21 -0
  71. cusrl/zoo/__init__.py +8 -0
  72. cusrl/zoo/experiment.py +105 -0
  73. cusrl/zoo/gym/__init__.py +2 -0
  74. cusrl/zoo/gym/box2d.py +63 -0
  75. cusrl/zoo/gym/classic_control.py +142 -0
  76. cusrl/zoo/isaaclab/__init__.py +2 -0
  77. cusrl/zoo/isaaclab/classic.py +69 -0
  78. cusrl/zoo/isaaclab/locomotion.py +93 -0
  79. cusrl/zoo/registry.py +70 -0
  80. cusrl-1.0.0.dist-info/METADATA +109 -0
  81. cusrl-1.0.0.dist-info/RECORD +83 -0
  82. cusrl-1.0.0.dist-info/WHEEL +5 -0
  83. cusrl-1.0.0.dist-info/top_level.txt +1 -0
cusrl/__init__.py ADDED
@@ -0,0 +1,107 @@
1
+ from cusrl import environment, hook, logger, module, preset, sampler, template, utils, zoo
2
+ from cusrl.environment import make_gym_env, make_gym_vec, make_isaaclab_env
3
+ from cusrl.module import (
4
+ CNN,
5
+ MLP,
6
+ RNN,
7
+ Actor,
8
+ AdaptiveNormalDist,
9
+ Denormalization,
10
+ Distribution,
11
+ DistributionFactoryLike,
12
+ FeedForward,
13
+ InferenceModule,
14
+ LayerFactoryLike,
15
+ Module,
16
+ ModuleFactory,
17
+ ModuleFactoryLike,
18
+ MultiheadSelfAttention,
19
+ NormalDist,
20
+ Normalization,
21
+ OneHotCategoricalDist,
22
+ Sequential,
23
+ Simba,
24
+ TransformerEncoderLayer,
25
+ Value,
26
+ )
27
+ from cusrl.sampler import (
28
+ AutoMiniBatchSampler,
29
+ MiniBatchSampler,
30
+ TemporalMiniBatchSampler,
31
+ )
32
+ from cusrl.template import (
33
+ ActorCritic,
34
+ Agent,
35
+ Buffer,
36
+ Environment,
37
+ EnvironmentSpec,
38
+ Hook,
39
+ Logger,
40
+ LoggerFactory,
41
+ LoggerFactoryLike,
42
+ OptimizerFactory,
43
+ Player,
44
+ Sampler,
45
+ Trainer,
46
+ Trial,
47
+ )
48
+ from cusrl.utils import (
49
+ device,
50
+ set_global_seed,
51
+ )
52
+
53
+ __all__ = [
54
+ "hook",
55
+ "logger",
56
+ "module",
57
+ "preset",
58
+ "sampler",
59
+ "template",
60
+ "utils",
61
+ "zoo",
62
+ "environment",
63
+ "Actor",
64
+ "ActorCritic",
65
+ "AdaptiveNormalDist",
66
+ "Agent",
67
+ "AutoMiniBatchSampler",
68
+ "Buffer",
69
+ "CNN",
70
+ "Denormalization",
71
+ "Distribution",
72
+ "DistributionFactoryLike",
73
+ "Environment",
74
+ "EnvironmentSpec",
75
+ "FeedForward",
76
+ "Hook",
77
+ "InferenceModule",
78
+ "LayerFactoryLike",
79
+ "Logger",
80
+ "LoggerFactory",
81
+ "LoggerFactoryLike",
82
+ "MLP",
83
+ "MiniBatchSampler",
84
+ "Module",
85
+ "ModuleFactory",
86
+ "ModuleFactoryLike",
87
+ "MultiheadSelfAttention",
88
+ "NormalDist",
89
+ "Normalization",
90
+ "OneHotCategoricalDist",
91
+ "OptimizerFactory",
92
+ "Player",
93
+ "RNN",
94
+ "Sampler",
95
+ "Sequential",
96
+ "Simba",
97
+ "TemporalMiniBatchSampler",
98
+ "Trainer",
99
+ "TransformerEncoderLayer",
100
+ "Trial",
101
+ "Value",
102
+ "device",
103
+ "make_gym_env",
104
+ "make_gym_vec",
105
+ "make_isaaclab_env",
106
+ "set_global_seed",
107
+ ]
@@ -0,0 +1,11 @@
1
+ from .gym import GymEnvAdapter, GymVectorEnvAdapter, make_gym_env, make_gym_vec
2
+ from .isaaclab import IsaacLabEnvAdapter, make_isaaclab_env
3
+
4
+ __all__ = [
5
+ "GymEnvAdapter",
6
+ "GymVectorEnvAdapter",
7
+ "IsaacLabEnvAdapter",
8
+ "make_gym_env",
9
+ "make_gym_vec",
10
+ "make_isaaclab_env",
11
+ ]
@@ -0,0 +1,157 @@
1
+ import random
2
+ import warnings
3
+ from collections.abc import Callable, Sequence
4
+ from typing import Any
5
+
6
+ import gymnasium as gym
7
+ import numpy as np
8
+ from gymnasium.envs.registration import EnvSpec
9
+
10
+ import cusrl.utils
11
+ from cusrl.template import Environment
12
+ from cusrl.utils.typing import Array, Slice
13
+
14
+ __all__ = ["GymEnvAdapter", "GymVectorEnvAdapter", "make_gym_env", "make_gym_vec"]
15
+
16
+
17
class GymEnvAdapter(Environment):
    """Adapter exposing a single Gymnasium environment as an ``Environment``.

    The wrapped environment must have a 1D ``Box`` observation space and
    either a 1D ``Box`` or a ``Discrete`` action space. For ``Discrete``
    spaces the action dimension equals the number of choices, and actions
    handed to :meth:`step` are decoded with ``argmax`` over the last axis.

    All arrays returned by :meth:`reset` and :meth:`step` carry a leading
    axis of size 1 because this adapter reports ``num_instances=1``.

    Args:
        wrapped (gym.Env):
            The Gymnasium environment to adapt.

    Raises:
        ValueError: If the observation or action space is not supported.
    """

    def __init__(self, wrapped: gym.Env):
        obs_space = wrapped.observation_space
        if not isinstance(obs_space, gym.spaces.Box):
            raise ValueError("Only Box observation space is supported.")
        if len(obs_space.shape) != 1:
            raise ValueError("Only 1D observation space is supported.")

        act_space = wrapped.action_space
        if isinstance(act_space, gym.spaces.Box):
            if len(act_space.shape) != 1:
                raise ValueError("For Box action space, only 1D action space is supported.")
            action_dim = act_space.shape[0]
        elif isinstance(act_space, gym.spaces.Discrete):
            action_dim = int(act_space.n)
        else:
            raise ValueError(f"Unsupported action space type: {wrapped.action_space}.")

        super().__init__(
            num_instances=1,
            observation_dim=obs_space.shape[0],
            action_dim=action_dim,
            spec={"gym_spec": wrapped.spec, "gym_metadata": wrapped.metadata},
        )
        # NOTE(review): only 4 random bits are drawn, i.e. 16 possible seeds.
        # Confirm this is intended rather than a wider seed (e.g. 32 bits).
        wrapped.reset(seed=random.getrandbits(4))
        self.wrapped = wrapped

    def _render_if_enabled(self):
        """Render the wrapped environment when it has a render mode set."""
        if self.wrapped.render_mode is not None:
            self.wrapped.render()

    def reset(self, *, indices: Array | Slice | None = None):
        """Reset the wrapped environment.

        ``indices`` is accepted for interface compatibility but is not used:
        there is only one instance, so the whole environment is reset.

        Returns:
            A tuple ``(observation, None, info)`` where ``observation`` has
            been reshaped to ``(1, -1)`` and ``info`` is the untouched info
            dict from Gymnasium.
        """
        raw_observation, info = self.wrapped.reset()
        batched_observation = raw_observation.reshape(1, -1)
        self._render_if_enabled()
        # TODO: process arrays in info
        return batched_observation, None, info

    def step(self, action: Array):
        """Advance the wrapped environment by one step.

        Args:
            action (Array):
                Action with a leading axis of size 1. For ``Discrete``
                action spaces the index of the largest entry along the last
                axis is used as the discrete action.

        Returns:
            A tuple ``(observation, None, reward, terminated, truncated,
            info)``; ``observation`` is reshaped to ``(1, -1)`` and the
            reward and both flags are wrapped into ``(1, 1)`` arrays.
        """
        if isinstance(self.wrapped.action_space, gym.spaces.Discrete):
            action = np.argmax(action, axis=-1)
        # Drop the leading instance axis before handing over to Gymnasium.
        single_action = action.squeeze(0)

        raw_observation, raw_reward, raw_terminated, raw_truncated, info = self.wrapped.step(single_action)
        batched_observation = raw_observation.reshape(1, -1)
        reward = np.array([[raw_reward]], dtype=np.float32)
        terminated = np.array([[raw_terminated]])
        truncated = np.array([[raw_truncated]])
        self._render_if_enabled()
        # TODO: process arrays in info
        return batched_observation, None, reward, terminated, truncated, info
62
+
63
+
64
class GymVectorEnvAdapter(Environment):
    """Adapter exposing a Gymnasium vector environment as an ``Environment``.

    Each sub-environment must have a 1D ``Box`` observation space and either
    a 1D ``Box`` or a ``Discrete`` action space. For ``Discrete`` spaces the
    action dimension equals the number of choices, and actions handed to
    :meth:`step` are decoded with ``argmax`` over the last axis.

    The vector environment must not reset sub-environments on its own: when
    ``wrapped.metadata`` carries an ``"autoreset_mode"`` entry it has to be
    ``AutoresetMode.DISABLED``; when the entry is absent a warning is issued
    on the main process instead.

    Args:
        wrapped (gym.vector.VectorEnv):
            The Gymnasium vector environment to adapt.

    Raises:
        ValueError: If the observation or action space is not supported, or
            if the autoreset mode is set to anything other than ``DISABLED``.
    """

    def __init__(self, wrapped: gym.vector.VectorEnv):
        if not isinstance(wrapped.single_observation_space, gym.spaces.Box):
            raise ValueError("Only Box observation space is supported.")
        if not len(wrapped.single_observation_space.shape) == 1:
            raise ValueError("Only 1D observation space is supported.")
        if isinstance(wrapped.single_action_space, gym.spaces.Box):
            if not len(wrapped.single_action_space.shape) == 1:
                raise ValueError("For Box action space, only 1D action space is supported.")
            action_dim = wrapped.single_action_space.shape[0]
        elif isinstance(wrapped.single_action_space, gym.spaces.Discrete):
            action_dim = int(wrapped.single_action_space.n)
        else:
            raise ValueError(f"Unsupported action space type: {wrapped.single_action_space}.")

        if (autoreset_mode := wrapped.metadata.get("autoreset_mode")) is None:
            # The mode cannot be verified, so only warn (once, on the main process).
            if cusrl.utils.is_main_process():
                warnings.warn("GymVectorEnvAdapter: make sure 'autoreset_mode' is 'DISABLED'.")
        elif autoreset_mode != gym.vector.AutoresetMode.DISABLED:
            raise ValueError("'autoreset_mode' of vector environments must be 'DISABLED'.")

        super().__init__(
            num_instances=wrapped.num_envs,
            observation_dim=wrapped.single_observation_space.shape[0],
            action_dim=action_dim,
            spec={"gym_spec": wrapped.spec, "gym_metadata": wrapped.metadata},
        )
        # NOTE(review): only 4 random bits are drawn, i.e. 16 possible seeds.
        # Confirm this is intended rather than a wider seed (e.g. 32 bits).
        wrapped.reset(seed=random.getrandbits(4))
        self.wrapped = wrapped

    def _to_reset_mask(self, indices) -> np.ndarray:
        """Convert ``indices`` into a boolean mask with one entry per instance.

        A boolean ``np.ndarray`` is treated as an already-built mask and is
        returned unchanged. Anything else (integer arrays, lists, slices,
        ...) is used to index a fresh all-``False`` mask. Previously every
        ``np.ndarray`` was passed through, so an integer index array was
        forwarded as the mask itself.
        """
        if isinstance(indices, np.ndarray) and indices.dtype == np.bool_:
            return indices
        mask = np.zeros(self.num_instances, dtype=bool)
        mask[indices] = True
        return mask

    def reset(self, *, indices: Array | Slice | None = None):
        """Reset all sub-environments, or only the selected ones.

        Args:
            indices (Array | Slice | None, optional):
                Which instances to reset. ``None`` resets everything.
                Otherwise the selection is forwarded to the wrapped
                environment as a boolean ``"reset_mask"`` option.

        Returns:
            A tuple ``(observation, None, info)`` as returned by the wrapped
            environment's ``reset``.
        """
        if indices is None:
            observation, info = self.wrapped.reset()
        else:
            reset_mask = self._to_reset_mask(indices)
            observation, info = self.wrapped.reset(options={"reset_mask": reset_mask})

        # Render after both full and partial resets, matching GymEnvAdapter.
        if self.wrapped.render_mode is not None:
            self.wrapped.render()
        # TODO: process arrays in info
        return observation, None, info

    def step(self, action: Array):
        """Advance all sub-environments by one step.

        Args:
            action (Array):
                Batched action. For ``Discrete`` action spaces the index of
                the largest entry along the last axis is used as the
                discrete action of each instance.

        Returns:
            A tuple ``(observation, None, reward, terminated, truncated,
            info)``; the reward and both flags are reshaped to ``(-1, 1)``
            and a NumPy reward is cast to ``float32``.
        """
        if isinstance(self.wrapped.single_action_space, gym.spaces.Discrete):
            action = np.argmax(action, axis=-1)
        observation, reward, terminated, truncated, info = self.wrapped.step(action)
        if isinstance(reward, np.ndarray):
            reward = reward.astype(np.float32)
        reward = reward.reshape(-1, 1)
        terminated = terminated.reshape(-1, 1)
        truncated = truncated.reshape(-1, 1)
        if self.wrapped.render_mode is not None:
            self.wrapped.render()
        # TODO: process arrays in info
        return observation, None, reward, terminated, truncated, info
122
+
123
+
124
def make_gym_env(
    id: str | EnvSpec,
    max_episode_steps: int | None = None,
    disable_env_checker: bool | None = None,
    **kwargs: Any,
) -> Environment:
    """Create a Gymnasium environment and wrap it in a ``GymEnvAdapter``.

    Args:
        id (str | EnvSpec):
            Environment id or spec, forwarded to ``gym.make``.
        max_episode_steps (int | None, optional):
            Forwarded to ``gym.make``. Defaults to None.
        disable_env_checker (bool | None, optional):
            Forwarded to ``gym.make``. Defaults to None.
        **kwargs:
            Additional keyword arguments forwarded to ``gym.make``.

    Returns:
        The adapted environment.
    """
    gym_env = gym.make(
        id=id,
        max_episode_steps=max_episode_steps,
        disable_env_checker=disable_env_checker,
        **kwargs,
    )
    return GymEnvAdapter(gym_env)
138
+
139
+
140
def make_gym_vec(
    id: str | EnvSpec,
    num_envs: int = 1,
    vectorization_mode: gym.VectorizeMode | str | None = None,
    vector_kwargs: dict[str, Any] | None = None,
    wrappers: Sequence[Callable[[gym.Env], gym.Wrapper]] | None = None,
    **kwargs,
) -> Environment:
    """Create a Gymnasium vector environment wrapped in a ``GymVectorEnvAdapter``.

    The ``"autoreset_mode"`` entry of the vector keyword arguments is always
    set to ``AutoresetMode.DISABLED``, overriding any value supplied by the
    caller. The caller's ``vector_kwargs`` dict itself is not modified.

    Args:
        id (str | EnvSpec):
            Environment id or spec, forwarded to ``gym.make_vec``.
        num_envs (int, optional):
            Number of sub-environments. Defaults to 1.
        vectorization_mode (gym.VectorizeMode | str | None, optional):
            Forwarded to ``gym.make_vec``. Defaults to None.
        vector_kwargs (dict[str, Any] | None, optional):
            Extra keyword arguments for the vector environment. Defaults to
            None.
        wrappers (Sequence[Callable[[gym.Env], gym.Wrapper]] | None, optional):
            Forwarded to ``gym.make_vec``. Defaults to None.
        **kwargs:
            Additional keyword arguments forwarded to ``gym.make_vec``.

    Returns:
        The adapted vector environment.
    """
    # Copy first so the caller's dict is left untouched, then force the mode.
    effective_vector_kwargs = {**(vector_kwargs or {})}
    effective_vector_kwargs["autoreset_mode"] = gym.vector.AutoresetMode.DISABLED

    vector_env = gym.make_vec(
        id=id,
        num_envs=num_envs,
        vectorization_mode=vectorization_mode,
        vector_kwargs=effective_vector_kwargs,
        wrappers=wrappers,
        **kwargs,
    )
    return GymVectorEnvAdapter(vector_env)
@@ -0,0 +1,133 @@
1
+ import argparse
2
+ from collections.abc import Sequence
3
+ from typing import Any
4
+
5
+ import gymnasium as gym
6
+ import torch
7
+
8
+ import cusrl.utils
9
+ from cusrl.template import Environment, EnvironmentSpec
10
+ from cusrl.utils.typing import Array, Slice
11
+
12
+ __all__ = ["IsaacLabEnvAdapter", "make_isaaclab_env"]
13
+
14
+
15
class IsaacLabEnvAdapter(Environment):
    """Adapter exposing an Isaac Lab task as an ``Environment``.

    Construction launches the Isaac Lab application, loads the task's
    environment configuration from the registry, and creates the environment
    through ``gym.make``. The ``"policy"`` observation group is returned as
    the observation and the optional ``"critic"`` group as the state.

    Args:
        id (str):
            Registered id of the Isaac Lab task.
        argv (Sequence[str] | None, optional):
            Command-line style arguments parsed for ``--num_envs`` and the
            app launcher options. ``None`` is treated as an empty list.
        **kwargs:
            Additional keyword arguments forwarded to ``gym.make``.

    Raises:
        ValueError: If the observation, action or state space is not 1D.
    """

    def __init__(self, id: str, argv: Sequence[str] | None = None, **kwargs):
        from isaaclab.app import AppLauncher

        arg_parser = argparse.ArgumentParser(prog="--environment-args", description="IsaacLab environment")
        arg_parser.add_argument("--num_envs", type=int, metavar="N", help="Number of environments to simulate.")
        AppLauncher.add_app_launcher_args(arg_parser)
        args = arg_parser.parse_args(argv or [])
        args.device = str(cusrl.device())
        self.app_launcher = AppLauncher(args)
        self.simulation_app = self.app_launcher.app

        # These imports intentionally come after the app launcher has been
        # created; the original statement order is preserved.
        from isaaclab.envs import DirectMARLEnv, DirectRLEnv, ManagerBasedRLEnv, multi_agent_to_single_agent
        from isaaclab_tasks.utils.parse_cfg import load_cfg_from_registry

        env_cfg = load_cfg_from_registry(id, "env_cfg_entry_point")
        env_cfg.sim.device = args.device
        if args.num_envs is not None:
            env_cfg.scene.num_envs = args.num_envs
        # Split the environments across processes, keeping at least one each.
        num_envs_per_process = env_cfg.scene.num_envs // cusrl.utils.distributed.world_size()
        env_cfg.scene.num_envs = max(num_envs_per_process, 1)

        isaaclab_env = gym.make(id, cfg=env_cfg, disable_env_checker=True, **kwargs)
        if isinstance(isaaclab_env.unwrapped, DirectMARLEnv):
            isaaclab_env = multi_agent_to_single_agent(isaaclab_env)
        self.wrapped: ManagerBasedRLEnv | DirectRLEnv = isaaclab_env.unwrapped
        self.device = self.wrapped.device
        self.metrics = cusrl.utils.Metrics()
        super().__init__(
            self.wrapped.num_envs,
            self._get_observation_dim(),
            self._get_action_dim(),
            self._get_state_dim(),
            EnvironmentSpec(autoreset=True, final_state_is_missing=True),
        )

        # Avoid terminal color issues
        print("\033[0m", end="")

    def __del__(self):
        # Attributes may be missing if construction failed part-way through.
        if hasattr(self, "wrapped"):
            self.wrapped.close()
        if hasattr(self, "simulation_app"):
            self.simulation_app.close()

    def _get_observation_dim(self) -> int:
        """Return the size of the 1D ``"policy"`` observation group."""
        env = self.wrapped
        if hasattr(env, "observation_manager"):
            shape = env.observation_manager.group_obs_dim["policy"]
        else:
            shape = env.single_observation_space["policy"].shape

        if len(shape) != 1:
            raise ValueError("Only 1D observation space is supported. ")
        return shape[0]

    def _get_action_dim(self) -> int:
        """Return the action dimension of the wrapped environment."""
        env = self.wrapped
        if hasattr(env, "action_manager"):
            return env.action_manager.total_action_dim
        action_space = env.single_action_space
        if len(action_space.shape) != 1:
            raise ValueError("Only 1D action space is supported. ")
        return action_space.shape[0]

    def _get_state_dim(self) -> int | None:
        """Return the size of the 1D ``"critic"`` group, or None if absent."""
        env = self.wrapped
        shape = None
        if hasattr(env, "observation_manager"):
            shape = env.observation_manager.group_obs_dim.get("critic")
        else:
            critic_space = env.single_observation_space.get("critic")
            if critic_space is not None:
                shape = critic_space.shape

        if shape is None:
            return None
        if len(shape) != 1:
            raise ValueError("Only 1D state space is supported. ")
        return shape[0]

    def reset(self, *, indices: Array | Slice | None = None):
        """Reset all instances, or only the selected ones.

        Args:
            indices (Array | Slice | None, optional):
                Which instances to reset. ``None`` resets everything and
                also re-draws ``episode_length_buf`` at random below
                ``max_episode_length``.

        Returns:
            A tuple ``(observation, state, extras)``. For a partial reset
            every returned tensor is indexed by ``indices``; ``extras``
            holds the remaining observation groups.
        """
        if indices is None:
            observation_dict, _ = self.wrapped.reset()
            # Presumably desynchronizes episode ends across instances.
            self.wrapped.episode_length_buf.random_(int(self.wrapped.max_episode_length))
            observation = observation_dict.pop("policy")
            state = observation_dict.pop("critic", None)
            return observation, state, observation_dict

        if isinstance(indices, slice):
            indices = torch.arange(self.num_instances, device=self.device)[indices]
        env_ids = torch.as_tensor(indices, device=self.device)
        observation_dict, _ = self.wrapped.reset(env_ids=env_ids)

        observation = observation_dict.pop("policy", None)
        state = observation_dict.pop("critic", None)
        extras = {}
        for key, value in observation_dict.items():
            extras[key] = value[indices]
        if observation is not None:
            observation = observation[indices]
        if state is not None:
            state = state[indices]
        return observation, state, extras

    def step(self, action):
        """Advance the simulation by one step.

        Entries under ``"log"`` and ``"episode"`` in the extras returned by
        the wrapped environment are recorded into ``self.metrics``.

        Returns:
            A tuple ``(observation, state, reward, terminated, truncated,
            extras)``; the reward and both flags get a trailing axis of size
            1, and ``extras`` holds the remaining observation groups.
        """
        observation_dict, reward, terminated, truncated, step_extras = self.wrapped.step(action)
        observation = observation_dict.pop("policy")
        state = observation_dict.pop("critic", None)
        self.metrics.record(
            **step_extras.get("log", {}),
            **step_extras.get("episode", {}),
        )
        return (
            observation,
            state,
            reward.unsqueeze(-1),
            terminated.unsqueeze(-1),
            truncated.unsqueeze(-1),
            observation_dict,
        )

    def get_metrics(self):
        """Return the summary of the recorded metrics and clear them."""
        summary = self.metrics.summary()
        self.metrics.clear()
        return summary
130
+
131
+
132
def make_isaaclab_env(id: str, argv: Sequence[str] | None = None, **kwargs: Any) -> Environment:
    """Create an ``IsaacLabEnvAdapter`` for the Isaac Lab task ``id``.

    Args:
        id (str):
            Registered id of the Isaac Lab task.
        argv (Sequence[str] | None, optional):
            Command-line style arguments passed to the adapter. Defaults to
            None.
        **kwargs:
            Additional keyword arguments passed to the adapter.

    Returns:
        The adapted environment.
    """
    adapter = IsaacLabEnvAdapter(id, argv, **kwargs)
    return adapter
cusrl/hook/__init__.py ADDED
@@ -0,0 +1,58 @@
1
+ from .advantage import AdvantageNormalization, AdvantageReduction
2
+ from .condition import ConditionalObjectiveActivation
3
+ from .gae import GAE
4
+ from .gradient import GradientClipping
5
+ from .initialization import ModuleInitialization
6
+ from .lr_schedule import AdaptiveLRSchedule, MiniBatchWiseLRSchedule, ThresholdLRSchedule
7
+ from .normalization import ObservationNormalization
8
+ from .on_policy import OnPolicyPreparation
9
+ from .ppo import EntropyLoss, PPOSurrogateLoss
10
+ from .representation import (
11
+ NextStatePrediction,
12
+ ReturnPrediction,
13
+ StatePrediction,
14
+ )
15
+ from .rnd import RandomNetworkDistillation
16
+ from .schedule import (
17
+ HookActivationSchedule,
18
+ OnPolicyBufferCapacitySchedule,
19
+ ParameterSchedule,
20
+ )
21
+ from .smoothness import ActionSmoothnessLoss
22
+ from .statistics import OnPolicyStatistics
23
+ from .symmetry import (
24
+ SymmetricArchitecture,
25
+ SymmetricDataAugmentation,
26
+ SymmetryLoss,
27
+ )
28
+ from .value import ValueComputation, ValueLoss
29
+
30
+ __all__ = [
31
+ "ActionSmoothnessLoss",
32
+ "AdaptiveLRSchedule",
33
+ "AdvantageNormalization",
34
+ "AdvantageReduction",
35
+ "ConditionalObjectiveActivation",
36
+ "EntropyLoss",
37
+ "GAE",
38
+ "GradientClipping",
39
+ "HookActivationSchedule",
40
+ "MiniBatchWiseLRSchedule",
41
+ "ModuleInitialization",
42
+ "NextStatePrediction",
43
+ "ObservationNormalization",
44
+ "OnPolicyBufferCapacitySchedule",
45
+ "OnPolicyPreparation",
46
+ "OnPolicyStatistics",
47
+ "ParameterSchedule",
48
+ "PPOSurrogateLoss",
49
+ "RandomNetworkDistillation",
50
+ "ReturnPrediction",
51
+ "StatePrediction",
52
+ "SymmetricArchitecture",
53
+ "SymmetricDataAugmentation",
54
+ "SymmetryLoss",
55
+ "ThresholdLRSchedule",
56
+ "ValueComputation",
57
+ "ValueLoss",
58
+ ]
@@ -0,0 +1,100 @@
1
+ from collections.abc import Sequence
2
+ from typing import Literal
3
+
4
+ import torch
5
+ from torch import Tensor
6
+
7
+ from cusrl.template import ActorCritic, Hook
8
+ from cusrl.utils import distributed
9
+
10
+ __all__ = ["AdvantageReduction", "AdvantageNormalization"]
11
+
12
+
13
class AdvantageReduction(Hook):
    """A hook that collapses the last dimension of the advantage tensor.

    Advantages whose last dimension already has size 1 are left untouched.
    Otherwise the advantage is optionally multiplied element-wise by
    ``weight`` and then summed or averaged over its last dimension, keeping
    that dimension with size 1. This is useful in multi-goal settings where
    the advantage is a vector.

    Args:
        reduction (Literal["sum", "mean"], optional):
            The reduction method to apply, either "sum" or "mean".
            Defaults to "sum".
        weight (Sequence[float] | None, optional):
            Optional weights applied element-wise to the advantage before
            the reduction. Defaults to None.

    Raises:
        ValueError: If an invalid reduction method is provided.
    """

    def __init__(
        self,
        reduction: Literal["sum", "mean"] = "sum",
        weight: Sequence[float] | None = None,
    ):
        if not (reduction == "sum" or reduction == "mean"):
            raise ValueError(f"Unknown reduction: '{reduction}'.")

        self.reduction = reduction
        # Holds the raw sequence until `init` converts it to a tensor.
        self.weight: Tensor | None = weight

    def init(self):
        if self.weight is None:
            return
        self.weight = self.agent.to_tensor(self.weight)

    def objective(self, batch: dict[str, Tensor]):
        advantage = batch["advantage"]
        if advantage.size(-1) == 1:
            # Nothing to reduce.
            return

        if self.weight is not None:
            advantage = advantage * self.weight

        if self.reduction == "sum":
            reduced = advantage.sum(-1, keepdim=True)
        elif self.reduction == "mean":
            reduced = advantage.mean(-1, keepdim=True)
        else:
            raise ValueError(f"Unknown reduction: '{self.reduction}'.")
        batch["advantage"] = reduced
58
+
59
+
60
class AdvantageNormalization(Hook[ActorCritic]):
    """A hook that standardizes advantages in actor-critic algorithms.

    The advantage tensor is shifted and scaled in place so that, over all
    dimensions except the last, it has zero mean and unit standard
    deviation (a small epsilon is added to the variance for numerical
    stability). Depending on the configuration this happens either once on
    the whole buffer before the update starts, or on every mini-batch while
    the objective is computed.

    In distributed training the mean and variance can be synchronized
    across all processes before they are applied.

    Args:
        mini_batch_wise (bool, optional):
            If `True`, normalization is applied to each mini-batch instead
            of once to the whole buffer. Defaults to `False`.
        synchronize (bool, optional):
            If `True`, the mean and variance are synchronized across all
            processes in distributed training. Defaults to `True`.
    """

    def __init__(self, mini_batch_wise: bool = False, synchronize: bool = True):
        self.mini_batch_wise = mini_batch_wise
        self.synchronize = synchronize

    def pre_update(self, buffer):
        if self.mini_batch_wise:
            return
        self.normalize_(buffer["advantage"])

    def objective(self, batch):
        if not self.mini_batch_wise:
            return
        self.normalize_(batch["advantage"])

    @torch.no_grad()
    def normalize_(self, advantage: Tensor):
        """Standardize ``advantage`` in place over all but its last dimension."""
        leading_dims = tuple(range(advantage.ndim - 1))
        variance, mean = torch.var_mean(advantage, dim=leading_dims, correction=0)
        if self.synchronize:
            distributed.reduce_mean_var_(mean, variance)
        std = (variance + 1e-8).sqrt()
        advantage.sub_(mean)
        advantage.div_(std)
@@ -0,0 +1,57 @@
1
+ from collections.abc import Callable, Iterable
2
+ from typing import Any
3
+
4
+ import torch
5
+
6
+ from cusrl.template import ActorCritic, Buffer, Hook
7
+
8
+ __all__ = ["ConditionalObjectiveActivation", "EpochIndexCondition"]
9
+
10
+
11
class EpochIndexCondition:
    """Condition that holds when the batch's epoch index is one of a given set.

    Args:
        epoch_index (int | Iterable[int]):
            A single epoch index or an iterable of epoch indices for which
            the condition evaluates to `True`.
    """

    def __init__(self, epoch_index: int | Iterable[int]):
        # A lone integer is treated as a one-element collection.
        accepted = (epoch_index,) if isinstance(epoch_index, int) else epoch_index
        self.epoch_index = set(accepted)

    def __call__(self, agent: ActorCritic, batch: dict[str, Any]) -> bool:
        current_epoch = batch["epoch_index"]
        return current_epoch in self.epoch_index
26
+
27
+
28
class ConditionalObjectiveActivation(Hook[ActorCritic]):
    """A hook that switches other objective hooks on or off per batch.

    Before an update starts, the current ``active`` flag of every controlled
    hook is remembered. For each batch, a controlled hook is active only if
    it was active before the update and its condition holds. After the
    update the remembered flags are restored.

    This hook must be placed before any objective hooks it controls.

    Args:
        **named_conditions (Callable[[ActorCritic, dict[str, Any]], bool]):
            Keyword arguments mapping the name of an objective hook to a
            callable condition. The condition receives the agent and the
            current batch and returns `True` if the hook should be active,
            `False` otherwise.
    """

    def __init__(self, **named_conditions: Callable[[ActorCritic, dict[str, Any]], bool]):
        self.named_conditions = named_conditions
        self.named_activation = {}

    def pre_update(self, buffer: "Buffer"):
        # Remember how each controlled hook was configured before the update.
        for hook_name in self.named_conditions:
            controlled_hook = self.agent.hook[hook_name]
            self.named_activation[hook_name] = controlled_hook.active

    def objective(self, batch: dict[str, torch.Tensor]) -> torch.Tensor | None:
        for hook_name, condition in self.named_conditions.items():
            was_active = self.named_activation[hook_name]
            # The condition is only evaluated for hooks that were active.
            self.agent.hook[hook_name].active = was_active and condition(self.agent, batch)

    def post_update(self):
        # Put every controlled hook back into its pre-update state.
        for hook_name in self.named_conditions:
            controlled_hook = self.agent.hook[hook_name]
            controlled_hook.active = self.named_activation[hook_name]