cusrl 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cusrl/__init__.py +107 -0
- cusrl/environment/__init__.py +11 -0
- cusrl/environment/gym.py +157 -0
- cusrl/environment/isaaclab.py +133 -0
- cusrl/hook/__init__.py +58 -0
- cusrl/hook/advantage.py +100 -0
- cusrl/hook/condition.py +57 -0
- cusrl/hook/gae.py +143 -0
- cusrl/hook/gradient.py +48 -0
- cusrl/hook/initialization.py +94 -0
- cusrl/hook/lr_schedule.py +178 -0
- cusrl/hook/normalization.py +194 -0
- cusrl/hook/on_policy.py +35 -0
- cusrl/hook/ppo.py +77 -0
- cusrl/hook/representation.py +132 -0
- cusrl/hook/rnd.py +66 -0
- cusrl/hook/schedule.py +114 -0
- cusrl/hook/smoothness.py +75 -0
- cusrl/hook/statistics.py +28 -0
- cusrl/hook/symmetry.py +233 -0
- cusrl/hook/value.py +158 -0
- cusrl/launch/export.py +43 -0
- cusrl/launch/play.py +45 -0
- cusrl/launch/train.py +62 -0
- cusrl/logger/__init__.py +5 -0
- cusrl/logger/make_factory.py +18 -0
- cusrl/logger/tensorboard_logger.py +28 -0
- cusrl/logger/wandb_logger.py +68 -0
- cusrl/module/__init__.py +39 -0
- cusrl/module/actor.py +203 -0
- cusrl/module/attention.py +614 -0
- cusrl/module/bijector.py +115 -0
- cusrl/module/cnn.py +75 -0
- cusrl/module/critic.py +73 -0
- cusrl/module/distribution.py +263 -0
- cusrl/module/inference.py +57 -0
- cusrl/module/mlp.py +63 -0
- cusrl/module/module.py +182 -0
- cusrl/module/normalization.py +59 -0
- cusrl/module/rnn.py +167 -0
- cusrl/module/sequential.py +70 -0
- cusrl/module/simba.py +70 -0
- cusrl/preset/__init__.py +5 -0
- cusrl/preset/ppo.py +216 -0
- cusrl/sampler/__init__.py +11 -0
- cusrl/sampler/mini_batch_sampler.py +78 -0
- cusrl/template/__init__.py +27 -0
- cusrl/template/actor_critic.py +321 -0
- cusrl/template/agent.py +259 -0
- cusrl/template/buffer.py +271 -0
- cusrl/template/environment.py +208 -0
- cusrl/template/hook.py +244 -0
- cusrl/template/logger.py +76 -0
- cusrl/template/optimizer.py +68 -0
- cusrl/template/player.py +114 -0
- cusrl/template/trainer.py +290 -0
- cusrl/template/trial.py +103 -0
- cusrl/utils/__init__.py +30 -0
- cusrl/utils/cli.py +59 -0
- cusrl/utils/config.py +75 -0
- cusrl/utils/distributed.py +146 -0
- cusrl/utils/export.py +98 -0
- cusrl/utils/helper.py +122 -0
- cusrl/utils/metrics.py +72 -0
- cusrl/utils/nest.py +82 -0
- cusrl/utils/normalizer.py +276 -0
- cusrl/utils/recurrent.py +163 -0
- cusrl/utils/timing.py +63 -0
- cusrl/utils/typing.py +45 -0
- cusrl/utils/video.py +21 -0
- cusrl/zoo/__init__.py +8 -0
- cusrl/zoo/experiment.py +105 -0
- cusrl/zoo/gym/__init__.py +2 -0
- cusrl/zoo/gym/box2d.py +63 -0
- cusrl/zoo/gym/classic_control.py +142 -0
- cusrl/zoo/isaaclab/__init__.py +2 -0
- cusrl/zoo/isaaclab/classic.py +69 -0
- cusrl/zoo/isaaclab/locomotion.py +93 -0
- cusrl/zoo/registry.py +70 -0
- cusrl-1.0.0.dist-info/METADATA +109 -0
- cusrl-1.0.0.dist-info/RECORD +83 -0
- cusrl-1.0.0.dist-info/WHEEL +5 -0
- cusrl-1.0.0.dist-info/top_level.txt +1 -0
cusrl/hook/gae.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
from cusrl.template import ActorCritic, Hook
|
|
6
|
+
from cusrl.utils import ExponentialMovingNormalizer
|
|
7
|
+
|
|
8
|
+
__all__ = ["GAE"]
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@torch.jit.script
def _generalized_advantage_estimation(
    reward: torch.Tensor,
    done: torch.Tensor,
    value: torch.Tensor,
    next_value: torch.Tensor,
    gamma: float,
    lamda: float,
) -> torch.Tensor:
    """Compute GAE advantages with a backward recursion over dimension 0.

    First the one-step TD residuals ``delta_t = r_t + gamma * next_value_t - value_t``
    are formed for every step. They are then accumulated from the last step backwards
    as ``A_t = delta_t + gamma * lamda * (1 - done_t) * A_{t+1}``. ``next_value`` is
    used exactly as given; no terminal masking is applied to it here.

    Args:
        reward: Rewards; dimension 0 is iterated as the time axis.
        done: Boolean flags indexed like ``reward`` along dimension 0; a ``True`` at
            step t stops the accumulation from step t + 1 into step t.
        value: Value estimates of the current states.
        next_value: Value estimates of the successor states.
        gamma: Discount factor.
        lamda: GAE smoothing factor.

    Returns:
        A newly allocated tensor of advantages (the TD residuals, accumulated in place).
    """
    # Per-step decay; zero wherever an episode ended so the recursion is cut there.
    decay = done.logical_not() * (gamma * lamda)
    advantage = reward + next_value * gamma - value
    horizon = advantage.size(0)
    # Walk t = horizon - 2, ..., 0; the last step keeps its plain TD residual.
    for offset in range(horizon - 1):
        t = horizon - 2 - offset
        advantage[t] += decay[t] * advantage[t + 1]
    return advantage
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class GAE(Hook[ActorCritic]):
    """A hook that computes advantages and returns using Generalized Advantage Estimation (GAE).

    GAE is described in:
    "High-Dimensional Continuous Control Using Generalized Advantage Estimation",
    https://arxiv.org/abs/1506.02438

    Distinct lambda values can be enabled to individually control the bias-variance trade-offs
    for policy and value function, described in:
    "DNA: Proximal Policy Optimization with a Dual Network Architecture"
    https://proceedings.neurips.cc/paper_files/paper/2022/hash/e95475f5fb8edb9075bf9e25670d4013-Abstract-Conference.html

    PopArt normalization can be applied to the value function, described in:
    "Learning values across many orders of magnitude",
    https://proceedings.neurips.cc/paper/2016/hash/5227b6aaf294f5f027273aebf16015f2-Abstract.html

    The hook writes ``"advantage"`` and ``"return"`` entries into the whole buffer once before
    each update, or into every mini-batch passed to the objective when ``recompute`` is True.
    With PopArt enabled, the stored ``"value"`` / ``"next_value"`` entries are unnormalized with
    the critic's statistics before estimation, and the resulting returns are normalized with
    those same statistics.

    Args:
        gamma (float, optional):
            Discount factor for future rewards, in [0, 1). Defaults to 0.99.
        lamda (float, optional):
            Smoothing factor for advantage estimation, in [0, 1]. Defaults to 0.95.
        lamda_value (float | None, optional):
            Smoothing factor for value function calculation, in [0, 1].
            If None, the same value as `lamda` is used. Defaults to None.
        recompute (bool, optional):
            If True, recompute advantages and returns after each update. Defaults to False.
        popart_alpha (float | None, optional):
            If not None, applies PopArt normalization to the value function with the specified
            alpha. Defaults to None, which means no normalization is applied.

    Raises:
        ValueError: If ``gamma``, ``lamda`` or ``lamda_value`` lies outside its valid range.
    """

    MODULES = ["value_rms"]
    MUTABLE_ATTRS = ["gamma", "lamda"]

    def __init__(
        self,
        gamma: float = 0.99,
        lamda: float = 0.95,
        lamda_value: float | None = None,
        recompute: bool = False,
        popart_alpha: float | None = None,
    ):
        if gamma < 0 or gamma >= 1:
            raise ValueError(f"Invalid gamma value {gamma}, which should be in [0, 1).")
        if lamda < 0 or lamda > 1:
            raise ValueError(f"Invalid lambda value {lamda}, which should be in [0, 1].")
        if lamda_value is not None and (lamda_value < 0 or lamda_value > 1):
            raise ValueError(f"Invalid lambda value for value function {lamda_value}, which should be in [0, 1].")

        self.gamma = gamma
        self.lamda = lamda
        self.lamda_value = lamda_value
        self.recompute = recompute
        self.popart_alpha = popart_alpha
        # Statistics accumulated from the newest returns (PopArt only); copied into the
        # critic's normalizer in `post_update`.
        self.value_rms: ExponentialMovingNormalizer | None = None

    def init(self):
        """Create the PopArt normalizers when ``popart_alpha`` is set."""
        if self.popart_alpha is not None:
            self.value_rms = self.__make_normalizer(self.agent.value_dim)
            # The critic's normalizer holds the statistics its value head currently targets.
            self.agent.critic.value_rms = self.__make_normalizer(self.agent.value_dim)

    def __make_normalizer(self, num_channels: int):
        return ExponentialMovingNormalizer(num_channels, alpha=self.popart_alpha).to(self.agent.device)

    def pre_update(self, buffer):
        """Compute advantages/returns for the whole buffer unless they are recomputed per batch."""
        if not self.recompute:
            self._compute_advantage_and_return(buffer)

    def objective(self, batch: dict[str, Any]):
        """Recompute advantages/returns for each mini-batch when ``recompute`` is True."""
        if self.recompute:
            self._compute_advantage_and_return(batch)

    def post_update(self):
        """Apply the PopArt output-preserving rescale of the value head, then adopt the new statistics.

        With old statistics (m, s), new statistics (m', s'), and a head ``y = W x + b``, the
        choice ``W' = W * s / s'`` and ``b' = (s * b + m - m') / s'`` leaves
        ``s' * (W' x + b') + m'`` equal to ``s * (W x + b) + m``. This assumes the normalizer
        unnormalizes as ``std * y + mean`` (TODO confirm against ExponentialMovingNormalizer).
        """
        if self.value_rms is None:
            return
        critic = self.agent.critic
        old_value_rms: ExponentialMovingNormalizer = critic.value_rms
        old_mean, old_std = old_value_rms.mean, old_value_rms.std
        new_mean, new_std = self.value_rms.mean, self.value_rms.std
        value_head = critic.value_head
        with torch.no_grad():
            # Weight rows correspond to output channels (nn.Linear layout: out x in), so the
            # per-channel ratio is reshaped to a column. A flat (value_dim,) ratio would broadcast
            # along the input axis and scale the wrong dimension whenever value_dim > 1.
            value_head.weight.data.mul_((old_std / new_std).reshape(-1, 1))
            value_head.bias.data.mul_(old_std).add_(old_mean).sub_(new_mean).div_(new_std)
        old_value_rms.load_state_dict(self.value_rms.state_dict())

    @torch.no_grad()
    def _compute_advantage_and_return(self, data):
        """Write ``"advantage"`` and ``"return"`` into ``data`` (a buffer or a mini-batch)."""
        value = data["value"]
        next_value = data["next_value"]
        if (value_rms := self.agent.critic.value_rms) is not None:
            # Stored values are in normalized space; GAE must operate on raw values.
            value = value_rms.unnormalize(value)
            next_value = value_rms.unnormalize(next_value)

        data["advantage"] = _generalized_advantage_estimation(
            reward=data["reward"],
            done=data["done"],
            value=value,
            next_value=next_value,
            gamma=self.gamma,
            lamda=self.lamda,
        )

        # Returns use a separate lambda when `lamda_value` is set (DNA-style decoupling).
        if self.lamda_value is None:
            value_advantage = data["advantage"]
        else:
            value_advantage = _generalized_advantage_estimation(
                reward=data["reward"],
                done=data["done"],
                value=value,
                next_value=next_value,
                gamma=self.gamma,
                lamda=self.lamda_value,
            )
        data["return"] = value + value_advantage

        if value_rms is not None:
            # Track raw returns for the next statistics, then store returns in the critic's
            # current normalized space.
            self.value_rms.update(data["return"])
            value_rms.normalize_(data["return"])
|
cusrl/hook/gradient.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
from torch import nn
|
|
2
|
+
|
|
3
|
+
from cusrl.template import Hook
|
|
4
|
+
|
|
5
|
+
__all__ = ["GradientClipping"]
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class GradientClipping(Hook):
    """A hook to clip gradients of model parameters before the optimizer step,
    grouping parameters by name prefixes to handle varying gradient scales
    across modules.

    Each parameter is assigned to the group with the longest prefix that either
    equals its name or is followed by a ``"."`` in its name. Parameters that match
    no prefix, or whose optimizer param group carries no ``"param_names"``, fall
    into the default group (empty prefix). Gradients of each group are clipped
    jointly with ``torch.nn.utils.clip_grad_norm_``.

    Args:
        max_grad_norm (float | None, optional):
            Default max norm for gradient clipping. If None, no clipping is
            applied for the default group. Defaults to 1.0.
        **groups (float | None):
            Keyword arguments mapping parameter name prefixes to specific max
            gradient norms. None disables clipping for that group.

    Raises:
        ValueError: If any max norm is negative.
    """

    def __init__(self, max_grad_norm: float | None = 1.0, **groups: float | None):
        groups[""] = max_grad_norm
        for prefix, limit in groups.items():
            if limit is not None and limit < 0:
                # Zero is accepted, so the requirement is non-negativity.
                raise ValueError(f"'max_grad_norm' for prefix '{prefix}' must be non-negative.")
        # Sort by length of prefix (longest first) so the most specific prefix matches first
        self.groups = dict(sorted(groups.items(), key=lambda item: len(item[0]), reverse=True))

    def pre_optim(self, optimizer):
        """Collect the optimizer's parameters into prefix groups and clip each group."""
        prefixed_parameters = {prefix: [] for prefix in self.groups}
        for param_group in optimizer.param_groups:
            params = param_group["params"]
            # Unnamed parameters fall back to "" and thus to the default group.
            param_names = param_group.get("param_names", [""] * len(params))
            for param, name in zip(params, param_names, strict=True):
                prefixed_parameters[self._match_prefix(name)].append(param)
        # Clip gradients for each group
        for prefix, params in prefixed_parameters.items():
            if params and (limit := self.groups[prefix]) is not None:
                nn.utils.clip_grad_norm_(params, limit)

    def _match_prefix(self, name):
        """Return the longest registered prefix matching ``name``, or "" if none does."""
        # `self.groups` is ordered longest-prefix first, so the first hit is the most specific.
        for prefix in self.groups:
            if name == prefix or name.startswith(f"{prefix}."):
                return prefix
        return ""
|
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
import itertools
|
|
2
|
+
import math
|
|
3
|
+
from typing import Literal
|
|
4
|
+
|
|
5
|
+
from torch import nn
|
|
6
|
+
|
|
7
|
+
from cusrl.template import ActorCritic, Hook
|
|
8
|
+
|
|
9
|
+
__all__ = ["ModuleInitialization"]
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ModuleInitialization(Hook[ActorCritic]):
    """Initializes the parameters of the actor and/or critic when the agent is set up.

    The following module types are (re-)initialized:

    - Linear layers: orthogonal weights with gain ``scale``.
    - GRU/LSTM layers: orthogonal weights with gain ``scale``.
    - Multi-head attention input projections: orthogonal weights with gain ``scale``.
    - 2-D convolutions: Kaiming-normal weights.

    When ``zero_bias`` is True, biases are set to zero.

    Args:
        scale (float, optional): Gain for orthogonal initialization. Defaults to sqrt(2).
        scale_dist (float, optional): Gain for the actor's distribution mean head. If it
            differs from ``scale``, the mean head is re-initialized with this gain after the
            rest of the actor. Defaults to sqrt(2) * 0.1.
        zero_bias (bool, optional): Whether to set biases to zero. Defaults to True.
        conv_a (float, optional): ``a`` argument of ``nn.init.kaiming_normal_`` for Conv2d.
            Defaults to 0.0.
        conv_mode (str, optional): ``mode`` argument of ``kaiming_normal_``. Defaults to "fan_in".
        conv_nonlinearity (str, optional): ``nonlinearity`` argument of ``kaiming_normal_``.
            Defaults to "leaky_relu".
        init_actor (bool, optional): Whether to initialize the actor. Defaults to True.
        init_critic (bool, optional): Whether to initialize the critic. Defaults to True.
        distribution_std (float | None, optional): If not None and ``init_actor`` is True,
            passed to ``actor.set_distribution_std``. Defaults to None.
    """

    def __init__(
        self,
        scale: float = math.sqrt(2),
        scale_dist: float = math.sqrt(2) * 0.1,
        zero_bias: bool = True,
        conv_a: float = 0.0,
        conv_mode: Literal["fan_in", "fan_out"] = "fan_in",
        conv_nonlinearity: Literal["relu", "leaky_relu"] = "leaky_relu",
        init_actor: bool = True,
        init_critic: bool = True,
        distribution_std: float | None = None,
    ):
        self.scale = scale
        self.scale_dist = scale_dist
        self.zero_bias = zero_bias
        self.conv_a = conv_a
        self.conv_mode = conv_mode
        self.conv_nonlinearity = conv_nonlinearity
        self.init_actor = init_actor
        self.init_critic = init_critic
        self.distribution_std = distribution_std

    def init(self):
        if self.init_actor:
            actor = self.agent.actor
            for module in actor.modules():
                self._init_module(module, self.scale, self.zero_bias)
            if self.scale_dist != self.scale:
                # Re-initialize the mean head with its own (typically smaller) gain.
                self._init_linear(actor.distribution.mean_head, self.scale_dist, self.zero_bias)
            if self.distribution_std is not None:
                actor.set_distribution_std(self.distribution_std)
        if self.init_critic:
            for module in self.agent.critic.modules():
                self._init_module(module, self.scale, self.zero_bias)

    def _init_module(self, module: nn.Module, scale: float, zero_bias: bool):
        """Dispatch to the initializer for ``module``'s type; other types are left untouched."""
        if isinstance(module, nn.Linear):
            self._init_linear(module, scale, zero_bias)
        elif isinstance(module, (nn.LSTM, nn.GRU)):
            self._init_gru_lstm(module, scale, zero_bias)
        elif isinstance(module, nn.MultiheadAttention):
            self._init_mha(module, scale, zero_bias)
        elif isinstance(module, nn.Conv2d):
            self._init_conv2d(module, zero_bias)

    def _init_linear(self, module: nn.Linear, scale: float, zero_bias: bool):
        nn.init.orthogonal_(module.weight, gain=scale)
        if zero_bias and module.bias is not None:
            nn.init.zeros_(module.bias)

    def _init_gru_lstm(self, module: nn.GRU | nn.LSTM, scale: float, zero_bias: bool):
        """Orthogonally initialize every recurrent weight matrix and optionally zero the biases.

        PyTorch registers ``bias_*`` parameters only when ``module.bias`` is True, adds
        ``_reverse`` copies for bidirectional layers, and adds ``weight_hr_l*`` for LSTMs with
        ``proj_size > 0``; all of these cases are handled here.
        """
        suffixes = ("", "_reverse") if module.bidirectional else ("",)
        has_projection = getattr(module, "proj_size", 0) > 0
        for layer in range(module.num_layers):
            for suffix in suffixes:
                # hh before ih keeps the original RNG consumption order for unidirectional layers.
                nn.init.orthogonal_(getattr(module, f"weight_hh_l{layer}{suffix}"), gain=scale)
                nn.init.orthogonal_(getattr(module, f"weight_ih_l{layer}{suffix}"), gain=scale)
                if has_projection:
                    nn.init.orthogonal_(getattr(module, f"weight_hr_l{layer}{suffix}"), gain=scale)
                if zero_bias and module.bias:
                    nn.init.zeros_(getattr(module, f"bias_hh_l{layer}{suffix}"))
                    nn.init.zeros_(getattr(module, f"bias_ih_l{layer}{suffix}"))

    def _init_mha(self, module: nn.MultiheadAttention, scale: float, zero_bias: bool):
        """Initialize the input projections; ``out_proj`` is a Linear and is visited separately."""
        if module.in_proj_weight is not None:
            nn.init.orthogonal_(module.in_proj_weight, gain=scale)
        else:
            nn.init.orthogonal_(module.q_proj_weight, gain=scale)
            nn.init.orthogonal_(module.k_proj_weight, gain=scale)
            nn.init.orthogonal_(module.v_proj_weight, gain=scale)

        if zero_bias:
            if module.in_proj_bias is not None:
                nn.init.zeros_(module.in_proj_bias)
            if module.bias_k is not None:
                nn.init.zeros_(module.bias_k)
            if module.bias_v is not None:
                nn.init.zeros_(module.bias_v)

    def _init_conv2d(self, module: nn.Conv2d, zero_bias: bool):
        nn.init.kaiming_normal_(
            module.weight,
            a=self.conv_a,
            mode=self.conv_mode,
            nonlinearity=self.conv_nonlinearity,
        )
        if zero_bias and module.bias is not None:
            nn.init.zeros_(module.bias)
|
|
@@ -0,0 +1,178 @@
|
|
|
1
|
+
import math
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
from cusrl.template import ActorCritic, Hook
|
|
6
|
+
from cusrl.utils import distributed
|
|
7
|
+
|
|
8
|
+
__all__ = ["AdaptiveLRSchedule", "MiniBatchWiseLRSchedule", "ThresholdLRSchedule"]
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class KLDivergenceBasedLRSchedule(Hook[ActorCritic]):
    """Base class for learning-rate schedules driven by the measured policy KL divergence.

    After every update, the mean of the agent's ``"kl_divergence"`` metric is averaged
    across distributed processes and passed to :meth:`_compute_scale`, which subclasses
    implement. Any returned factor other than ``None`` or ``1.0`` multiplies the ``"lr"``
    of the selected optimizer parameter groups:

    - every group when ``scale_all_params`` is True;
    - otherwise only groups containing a parameter whose name starts with ``"actor."``.

    The cumulative factor is recorded as ``lr_scale`` and stored in the hook's state dict.

    Args:
        desired_kl_divergence (float, optional): Target KL divergence. Defaults to 0.01.
        scale_all_params (bool, optional): Scale all parameter groups instead of only the
            actor's. Defaults to False.

    Raises:
        ValueError: If ``desired_kl_divergence`` is not positive.
    """

    MUTABLE_ATTRS = ["desired_kl_divergence"]

    def __init__(
        self,
        desired_kl_divergence: float = 0.01,
        scale_all_params: bool = False,
    ):
        if desired_kl_divergence <= 0:
            raise ValueError("'desired_kl_divergence' must be positive.")

        self.desired_kl_divergence = desired_kl_divergence
        self.scale_all_params = scale_all_params
        # Product of all factors applied so far.
        self._lr_scale = 1.0

    def post_update(self):
        mean_kl = self.agent.metrics["kl_divergence"].mean.clone()
        distributed.reduce_mean_(mean_kl)
        factor = self._compute_scale(mean_kl.item())
        self._scale_lr_of_parameters(factor)
        self.agent.record(lr_scale=self._lr_scale)

    def _compute_scale(self, kl_divergence: float) -> float | None:
        """Return a multiplicative LR factor for ``kl_divergence``, or None for no change."""
        raise NotImplementedError

    def _scale_lr_of_parameters(self, scale: float | None):
        if scale is None or scale == 1.0:
            return

        self._lr_scale *= scale
        for group in self.agent.optimizer.param_groups:
            # `param_names` is only consulted when not every group is scaled.
            if not self.scale_all_params and not any(name.startswith("actor.") for name in group["param_names"]):
                continue
            group["lr"] *= scale

    def state_dict(self):
        return {"lr_scale": self._lr_scale}

    def load_state_dict(self, state_dict):
        self._lr_scale = state_dict["lr_scale"]
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class ThresholdLRSchedule(KLDivergenceBasedLRSchedule):
    """Adjusts the learning rate based on thresholded KL divergence.

    The KL divergence measured after an update is compared with a band around the target:

    - above ``desired_kl_divergence * threshold``: the LR is divided by ``scale_factor``;
    - below ``desired_kl_divergence / threshold``: the LR is multiplied by ``scale_factor``;
    - inside the band: the LR is left unchanged.

    Args:
        desired_kl_divergence (float, optional):
            Target KL divergence to maintain. Defaults to 0.01.
        threshold (float, optional):
            Ratio threshold (>1) for deciding when to adjust. Defaults to 1.2.
        scale_factor (float, optional):
            Multiplicative factor (>1) for scaling the LR. Defaults to 1.1.
        scale_all_params (bool, optional):
            If True, scales all optimizer parameter groups; otherwise only
            scales actor parameter groups. Defaults to False.

    Raises:
        ValueError: If ``desired_kl_divergence`` is not positive, or if ``threshold`` or
            ``scale_factor`` is not greater than 1.
    """

    MUTABLE_ATTRS = ["desired_kl_divergence"]

    def __init__(
        self,
        desired_kl_divergence: float = 0.01,
        threshold: float = 1.2,
        scale_factor: float = 1.1,
        scale_all_params: bool = False,
    ):
        super().__init__(desired_kl_divergence, scale_all_params)
        if threshold <= 1:
            raise ValueError("'threshold' must be greater than 1.")
        if scale_factor <= 1:
            raise ValueError("'scale_factor' must be greater than 1.")

        self.threshold = threshold
        self.scale_factor = scale_factor

    def _compute_scale(self, kl_divergence: float):
        upper_bound = self.desired_kl_divergence * self.threshold
        lower_bound = self.desired_kl_divergence / self.threshold
        if kl_divergence > upper_bound:
            # Policy moved too far: slow down.
            return 1 / self.scale_factor
        elif kl_divergence < lower_bound:
            # Policy barely moved: speed up.
            return self.scale_factor
        return None
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
class AdaptiveLRSchedule(KLDivergenceBasedLRSchedule):
    """Adaptively adjusts the learning rate based on accumulated KL divergence error.

    After each update, ``log(kl / desired_kl_divergence)`` is added to an accumulator.
    While the accumulator stays strictly within ``(-threshold, threshold)``, nothing
    changes. Once it leaves that interval:

    - the mean log-error since the last adjustment is clipped to [-1, 1];
    - the LR is scaled by ``exp(-clipped_mean * scale_factor)``;
    - the accumulator is reset.

    Args:
        desired_kl_divergence (float, optional):
            Target KL divergence to maintain. Defaults to 0.01.
        threshold (float, optional):
            Positive threshold for accumulated log-error before scaling. Defaults to 1.0.
        scale_factor (float, optional):
            Positive coefficient controlling adjustment magnitude. Defaults to 0.2.
        scale_all_params (bool, optional):
            If True, scales all optimizer parameter groups; otherwise only scales actor
            parameter groups. Defaults to False.

    Raises:
        ValueError: If ``desired_kl_divergence``, ``threshold`` or ``scale_factor`` is not positive.
    """

    MUTABLE_ATTRS = ["desired_kl_divergence"]

    def __init__(
        self,
        desired_kl_divergence: float = 0.01,
        threshold: float = 1.0,
        scale_factor: float = 0.2,
        scale_all_params: bool = False,
    ):
        super().__init__(desired_kl_divergence, scale_all_params)
        if threshold <= 0:
            raise ValueError("'threshold' must be positive.")
        if scale_factor <= 0:
            raise ValueError("'scale_factor' must be positive.")

        self.threshold = threshold
        self.scale_factor = scale_factor
        self.accumulated_log_error = 0.0
        self.count = 0

    def _compute_scale(self, kl_divergence: float):
        # A NaN KL carries no signal; accumulating it would turn the LR into NaN via exp(nan).
        if math.isnan(kl_divergence):
            return None
        # math.log raises on non-positive input, so floor the KL at a tiny positive value.
        # A near-zero KL then yields a large negative log-error, which favours increasing
        # the LR; its size is bounded by the clip to [-1, 1] below.
        min_kl_divergence = 1e-12
        kl_divergence = max(kl_divergence, min_kl_divergence)

        self.accumulated_log_error += math.log(kl_divergence / self.desired_kl_divergence)
        self.count += 1
        if self.threshold > self.accumulated_log_error > -self.threshold:
            return None
        average_log_error = self.accumulated_log_error / self.count
        scale = math.exp(-min(max(average_log_error, -1.0), 1.0) * self.scale_factor)
        self.accumulated_log_error = 0.0
        self.count = 0
        return scale
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
class MiniBatchWiseLRSchedule(ThresholdLRSchedule):
    """Applies a threshold-based LR schedule on a per-mini-batch KL divergence.

    Modified from [RSL-RL](https://github.com/leggedrobotics/rsl_rl).

    Behaviour compared with the base threshold schedule:

    - The LR of *all* optimizer parameter groups is adjusted after every mini-batch
      objective, not once per update; the per-update adjustment is disabled.
    - ``post_init`` turns on ``calculate_kl_divergence`` in the ``"OnPolicyPreparation"``
      hook, presumably what populates ``batch["kl_divergence"]``.

    Args:
        desired_kl_divergence (float, optional):
            Target KL divergence per mini-batch. Defaults to 0.01.
        threshold (float, optional):
            Ratio threshold for deciding scaling per batch. Defaults to 2.0.
        scale_factor (float, optional):
            Multiplicative factor for scaling the LR. Defaults to 1.5.
    """

    def __init__(
        self,
        desired_kl_divergence: float = 0.01,
        threshold: float = 2.0,
        scale_factor: float = 1.5,
    ):
        super().__init__(
            desired_kl_divergence,
            threshold=threshold,
            scale_factor=scale_factor,
            scale_all_params=True,
        )

    def post_init(self):
        preparation_hook = self.agent.hook["OnPolicyPreparation"]
        preparation_hook.calculate_kl_divergence = True

    def post_update(self):
        """Disable the base class's per-update adjustment; scaling happens per mini-batch."""

    def objective(self, batch):
        with torch.inference_mode():
            batch_kl = batch["kl_divergence"].mean()
            distributed.reduce_mean_(batch_kl)
        factor = self._compute_scale(batch_kl.item())
        self._scale_lr_of_parameters(factor)
        self.agent.record(lr_scale=self._lr_scale)
|
|
@@ -0,0 +1,194 @@
|
|
|
1
|
+
from collections.abc import Sequence
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
import torch
|
|
6
|
+
|
|
7
|
+
from cusrl.hook.symmetry import SymmetryDef
|
|
8
|
+
from cusrl.template import ActorCritic, Hook
|
|
9
|
+
from cusrl.utils import RunningMeanStd, mean_var_count
|
|
10
|
+
from cusrl.utils.export import ExportSpec
|
|
11
|
+
from cusrl.utils.typing import Slice
|
|
12
|
+
|
|
13
|
+
__all__ = ["ObservationNormalization"]
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class ObservationNormalization(Hook[ActorCritic]):
    """Normalizes observations and states using a running mean and standard deviation.

    This hook maintains a running estimate of the mean and standard deviation for
    observations and, if available, states. It intercepts transitions during data
    collection to normalize the `observation`, `state`, `next_observation`, and
    `next_state` fields. The original, unnormalized values are preserved under
    keys with an "original_" prefix.

    The running statistics are updated with new data from each step, unless the
    agent is in inference mode or the `frozen` attribute is set to `True`.

    The hook also handles scenarios where the observation is a subset of the
    state or where there is symmetry in the observations or states. When the
    observation is a subset of the state, the observation statistics are copied
    from the corresponding state channels rather than accumulated separately.
    The hook also synchronizes statistics across distributed processes. During
    model export, it attaches forward pre-hooks to the actor and critic models to
    ensure that inputs are automatically normalized.

    Args:
        max_count (int | None, optional):
            The maximum count for the running statistics to prevent numerical
            overflow. Defaults to None.
        defer_synchronization (bool, optional):
            If True, synchronization of running statistics in a distributed setting
            is deferred until the end of a rollout. This can improve performance
            by reducing the frequency of synchronization. Defaults to False.

    Raises:
        ValueError: If ``max_count`` is not positive; or (in ``init``) if the environment
            declares ``observation_is_subset_of_state`` without providing a state.
    """

    observation_rms: RunningMeanStd
    state_rms: RunningMeanStd | None = None
    _mirror_observation: SymmetryDef | None = None
    _mirror_state: SymmetryDef | None = None
    _observation_is_subset_of_state: Slice | torch.Tensor | None = None

    MODULES = ["observation_rms", "state_rms"]
    MUTABLE_ATTRS = ["frozen"]

    def __init__(self, max_count: int | None = None, defer_synchronization: bool = False):
        if max_count is not None and max_count <= 0:
            raise ValueError("'max_count' must be positive or None.")
        self.max_count = max_count
        self.frozen: bool = False
        self.defer_synchronization = defer_synchronization
        # Done flags of the previous step; used in `pre_act` to pick freshly reset environments.
        self._last_done: torch.Tensor | None = None

    def init(self):
        # Retrieve and normalize the subset index spec
        env_spec = self.agent.environment_spec
        observation_is_subset_of_state = env_spec.observation_is_subset_of_state
        if observation_is_subset_of_state is not None:
            if not self.agent.has_state:
                raise ValueError("'observation_is_subset_of_state' is set but state is not defined.")
            # Convert numpy or list indices to a tensor for consistent indexing
            if isinstance(observation_is_subset_of_state, (np.ndarray, Sequence)):
                observation_is_subset_of_state = self.agent.to_tensor(np.asarray(observation_is_subset_of_state))
            self._observation_is_subset_of_state = observation_is_subset_of_state

        if self._observation_is_subset_of_state is not None:
            # Statistics are copied from `state_rms`, so no count limit or stat groups are needed.
            self.observation_rms = self.__make_rms(self.agent.observation_dim)
        else:
            self.observation_rms = self.__make_rms(
                self.agent.observation_dim, self.max_count, env_spec.observation_stat_groups
            )
        if self.agent.has_state:
            self.state_rms = self.__make_rms(self.agent.state_dim, self.max_count, env_spec.state_stat_groups)
        self._mirror_observation = env_spec.mirror_observation
        self._mirror_state = env_spec.mirror_state

    def pre_act(self, transition: dict):
        observation, state = transition["observation"], transition.get("state")
        # On the first step, every observation is new. Afterwards, if final states are reported
        # separately, only environments that just finished provide a new (reset) observation; the
        # others were already counted as `next_observation` in `post_step`.
        if self._last_done is None or not self.agent.environment_spec.final_state_is_missing:
            self.__update_rms(observation, state, self._last_done)

        transition["original_observation"] = observation
        transition["observation"] = self.observation_rms.normalize(observation)
        if state is not None:
            transition["original_state"] = state
            transition["state"] = self.state_rms.normalize(state)

    def post_step(self, transition: dict):
        next_observation, next_state = transition["next_observation"], transition.get("next_state")
        self.__update_rms(next_observation, next_state)
        self._last_done = transition["done"].squeeze(-1)

        transition["original_next_observation"] = next_observation
        transition["next_observation"] = self.observation_rms.normalize(next_observation)
        if next_state is not None:
            transition["original_next_state"] = next_state
            transition["next_state"] = self.state_rms.normalize(next_state)

    def __make_rms(
        self,
        num_channels: int,
        max_count: int | None = None,
        stat_groups: tuple[tuple[int, int], ...] = (),
    ):
        normalizer = RunningMeanStd(num_channels, max_count=max_count).to(self.agent.device)
        for group in stat_groups:
            normalizer.register_stat_group(*group)
        return normalizer

    def __update_rms(
        self,
        observation: torch.Tensor,
        state: torch.Tensor | None,
        indices: torch.Tensor | None = None,
    ):
        if self.agent.inference_mode or self.frozen:
            # Do not update the statistics during inference or if frozen
            return

        if state is not None:
            self.__update_rms_impl(state, self.state_rms, self._mirror_state, indices)
        if self._observation_is_subset_of_state is not None:
            self.__copy_observation_stats_from_state()
        else:
            self.__update_rms_impl(observation, self.observation_rms, self._mirror_observation, indices)

    def __copy_observation_stats_from_state(self):
        """Overwrite the observation statistics with those of the observed state channels."""
        subset = self._observation_is_subset_of_state
        self.observation_rms.mean.copy_(self.state_rms.mean[subset])
        self.observation_rms.var.copy_(self.state_rms.var[subset])
        self.observation_rms.std.copy_(self.state_rms.std[subset])
        self.observation_rms.count = self.state_rms.count

    def __update_rms_impl(
        self,
        observation: torch.Tensor,
        rms: RunningMeanStd,
        mirror: SymmetryDef | None = None,
        indices: torch.Tensor | None = None,
    ):
        if indices is not None:
            observation = observation[indices]
        mean, var, count = mean_var_count(observation)
        if mirror is not None:
            # Pool the batch with its mirror image as two equally weighted halves:
            # var = mean of the variances + (half the difference of the means)^2.
            # abs() presumably guards against sign flips applied by the mirror transform.
            mirrored_mean = mirror(mean)
            mirrored_var = abs(mirror(var))
            var = (var + mirrored_var) / 2 + (mean - mirrored_mean) ** 2 / 4
            mean = (mean + mirrored_mean) / 2
        rms.update_from_stats(mean, var, count, synchronize=not self.defer_synchronization)

    def pre_update(self, buffer):
        """Synchronize deferred statistics across processes once per rollout."""
        if not self.defer_synchronization:
            return
        if self.state_rms is not None:
            self.state_rms.synchronize()
        if self._observation_is_subset_of_state is not None:
            self.__copy_observation_stats_from_state()
        else:
            self.observation_rms.synchronize()

    def export(self, export_data: dict[str, ExportSpec]):
        """Attach forward pre-hooks that normalize inputs of the exported actor (and critic)."""
        export_data["actor"].module.register_forward_pre_hook(
            self.__normalize_observation_forward_pre_hook, with_kwargs=True
        )
        if "critic" in export_data:
            export_data["critic"].module.register_forward_pre_hook(
                (
                    self.__normalize_state_forward_pre_hook
                    if self.agent.has_state
                    else self.__normalize_observation_forward_pre_hook
                ),
                with_kwargs=True,
            )

    @torch.no_grad()
    def __normalize_observation_forward_pre_hook(self, module, args: tuple, kwargs: dict[str, Any]):
        # The input is either the `observation` keyword argument or the first positional argument.
        if "observation" in kwargs:
            kwargs["observation"] = self.observation_rms.normalize(kwargs["observation"])
            return args, kwargs
        normalized_observation = self.observation_rms.normalize(args[0])
        return (normalized_observation, *args[1:]), kwargs

    @torch.no_grad()
    def __normalize_state_forward_pre_hook(self, module, args: tuple, kwargs: dict[str, Any]):
        # The input is either the `state` keyword argument or the first positional argument.
        if "state" in kwargs:
            kwargs["state"] = self.state_rms.normalize(kwargs["state"])
            return args, kwargs
        normalized_state = self.state_rms.normalize(args[0])
        return (normalized_state, *args[1:]), kwargs
|