cusrl 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. cusrl/__init__.py +107 -0
  2. cusrl/environment/__init__.py +11 -0
  3. cusrl/environment/gym.py +157 -0
  4. cusrl/environment/isaaclab.py +133 -0
  5. cusrl/hook/__init__.py +58 -0
  6. cusrl/hook/advantage.py +100 -0
  7. cusrl/hook/condition.py +57 -0
  8. cusrl/hook/gae.py +143 -0
  9. cusrl/hook/gradient.py +48 -0
  10. cusrl/hook/initialization.py +94 -0
  11. cusrl/hook/lr_schedule.py +178 -0
  12. cusrl/hook/normalization.py +194 -0
  13. cusrl/hook/on_policy.py +35 -0
  14. cusrl/hook/ppo.py +77 -0
  15. cusrl/hook/representation.py +132 -0
  16. cusrl/hook/rnd.py +66 -0
  17. cusrl/hook/schedule.py +114 -0
  18. cusrl/hook/smoothness.py +75 -0
  19. cusrl/hook/statistics.py +28 -0
  20. cusrl/hook/symmetry.py +233 -0
  21. cusrl/hook/value.py +158 -0
  22. cusrl/launch/export.py +43 -0
  23. cusrl/launch/play.py +45 -0
  24. cusrl/launch/train.py +62 -0
  25. cusrl/logger/__init__.py +5 -0
  26. cusrl/logger/make_factory.py +18 -0
  27. cusrl/logger/tensorboard_logger.py +28 -0
  28. cusrl/logger/wandb_logger.py +68 -0
  29. cusrl/module/__init__.py +39 -0
  30. cusrl/module/actor.py +203 -0
  31. cusrl/module/attention.py +614 -0
  32. cusrl/module/bijector.py +115 -0
  33. cusrl/module/cnn.py +75 -0
  34. cusrl/module/critic.py +73 -0
  35. cusrl/module/distribution.py +263 -0
  36. cusrl/module/inference.py +57 -0
  37. cusrl/module/mlp.py +63 -0
  38. cusrl/module/module.py +182 -0
  39. cusrl/module/normalization.py +59 -0
  40. cusrl/module/rnn.py +167 -0
  41. cusrl/module/sequential.py +70 -0
  42. cusrl/module/simba.py +70 -0
  43. cusrl/preset/__init__.py +5 -0
  44. cusrl/preset/ppo.py +216 -0
  45. cusrl/sampler/__init__.py +11 -0
  46. cusrl/sampler/mini_batch_sampler.py +78 -0
  47. cusrl/template/__init__.py +27 -0
  48. cusrl/template/actor_critic.py +321 -0
  49. cusrl/template/agent.py +259 -0
  50. cusrl/template/buffer.py +271 -0
  51. cusrl/template/environment.py +208 -0
  52. cusrl/template/hook.py +244 -0
  53. cusrl/template/logger.py +76 -0
  54. cusrl/template/optimizer.py +68 -0
  55. cusrl/template/player.py +114 -0
  56. cusrl/template/trainer.py +290 -0
  57. cusrl/template/trial.py +103 -0
  58. cusrl/utils/__init__.py +30 -0
  59. cusrl/utils/cli.py +59 -0
  60. cusrl/utils/config.py +75 -0
  61. cusrl/utils/distributed.py +146 -0
  62. cusrl/utils/export.py +98 -0
  63. cusrl/utils/helper.py +122 -0
  64. cusrl/utils/metrics.py +72 -0
  65. cusrl/utils/nest.py +82 -0
  66. cusrl/utils/normalizer.py +276 -0
  67. cusrl/utils/recurrent.py +163 -0
  68. cusrl/utils/timing.py +63 -0
  69. cusrl/utils/typing.py +45 -0
  70. cusrl/utils/video.py +21 -0
  71. cusrl/zoo/__init__.py +8 -0
  72. cusrl/zoo/experiment.py +105 -0
  73. cusrl/zoo/gym/__init__.py +2 -0
  74. cusrl/zoo/gym/box2d.py +63 -0
  75. cusrl/zoo/gym/classic_control.py +142 -0
  76. cusrl/zoo/isaaclab/__init__.py +2 -0
  77. cusrl/zoo/isaaclab/classic.py +69 -0
  78. cusrl/zoo/isaaclab/locomotion.py +93 -0
  79. cusrl/zoo/registry.py +70 -0
  80. cusrl-1.0.0.dist-info/METADATA +109 -0
  81. cusrl-1.0.0.dist-info/RECORD +83 -0
  82. cusrl-1.0.0.dist-info/WHEEL +5 -0
  83. cusrl-1.0.0.dist-info/top_level.txt +1 -0
cusrl/hook/gae.py ADDED
@@ -0,0 +1,143 @@
1
+ from typing import Any
2
+
3
+ import torch
4
+
5
+ from cusrl.template import ActorCritic, Hook
6
+ from cusrl.utils import ExponentialMovingNormalizer
7
+
8
+ __all__ = ["GAE"]
9
+
10
+
11
+ @torch.jit.script
12
+ def _generalized_advantage_estimation(
13
+ reward: torch.Tensor,
14
+ done: torch.Tensor,
15
+ value: torch.Tensor,
16
+ next_value: torch.Tensor,
17
+ gamma: float,
18
+ lamda: float,
19
+ ) -> torch.Tensor:
20
+ not_done = done.logical_not()
21
+ advantage = reward + next_value * gamma - value
22
+ for step in range(advantage.size(0) - 2, -1, -1):
23
+ advantage[step] += not_done[step] * (gamma * lamda) * advantage[step + 1]
24
+ return advantage
25
+
26
+
27
class GAE(Hook[ActorCritic]):
    """A hook that computes advantages and returns using Generalized Advantage Estimation (GAE).

    GAE is described in:
    "High-Dimensional Continuous Control Using Generalized Advantage Estimation",
    https://arxiv.org/abs/1506.02438

    Distinct lambda values can be enabled to individually control the bias-variance trade-offs
    for policy and value function, described in:
    "DNA: Proximal Policy Optimization with a Dual Network Architecture"
    https://proceedings.neurips.cc/paper_files/paper/2022/hash/e95475f5fb8edb9075bf9e25670d4013-Abstract-Conference.html

    PopArt normalization can be applied to the value function, described in:
    "Learning values across many orders of magnitude",
    https://proceedings.neurips.cc/paper/2016/hash/5227b6aaf294f5f027273aebf16015f2-Abstract.html

    Args:
        gamma (float, optional):
            Discount factor for future rewards, in [0, 1). Defaults to 0.99.
        lamda (float, optional):
            Smoothing factor for advantage estimation, in [0, 1]. Defaults to 0.95.
        lamda_value (float | None, optional):
            Smoothing factor for value function calculation, in [0, 1].
            If None, the same value as `lamda` is used. Defaults to None.
        recompute (bool, optional):
            If True, advantages and returns are computed inside `objective` for every
            batch instead of once in `pre_update` for the whole buffer. Defaults to False.
        popart_alpha (float | None, optional):
            If not None, applies PopArt normalization to the value function with the specified
            alpha. Defaults to None, which means no normalization is applied.

    Raises:
        ValueError: If `gamma`, `lamda` or `lamda_value` is outside its documented range.
    """

    MODULES = ["value_rms"]
    MUTABLE_ATTRS = ["gamma", "lamda"]

    def __init__(
        self,
        gamma: float = 0.99,
        lamda: float = 0.95,
        lamda_value: float | None = None,
        recompute: bool = False,
        popart_alpha: float | None = None,
    ):
        if gamma < 0 or gamma >= 1:
            raise ValueError(f"Invalid gamma value {gamma}, which should be in [0, 1).")
        if lamda < 0 or lamda > 1:
            raise ValueError(f"Invalid lambda value {lamda}, which should be in [0, 1].")
        if lamda_value is not None and (lamda_value < 0 or lamda_value > 1):
            raise ValueError(f"Invalid lambda value for value function {lamda_value}, which should be in [0, 1].")

        self.gamma = gamma
        self.lamda = lamda
        self.lamda_value = lamda_value
        self.recompute = recompute
        self.popart_alpha = popart_alpha
        # Created in `init` only when PopArt is enabled.
        self.value_rms: ExponentialMovingNormalizer | None = None

    def init(self):
        """Create the PopArt normalizers when `popart_alpha` is set."""
        if self.popart_alpha is not None:
            # Two separate normalizers: the hook's copy receives new statistics
            # during an update, while the critic's copy keeps the statistics the
            # value head is currently expressed in until `post_update` syncs them.
            self.value_rms = self.__make_normalizer(self.agent.value_dim)
            self.agent.critic.value_rms = self.__make_normalizer(self.agent.value_dim)

    def __make_normalizer(self, num_channels: int):
        """Build a normalizer with `popart_alpha` on the agent's device."""
        return ExponentialMovingNormalizer(num_channels, alpha=self.popart_alpha).to(self.agent.device)

    def pre_update(self, buffer):
        """Compute advantages and returns once for the whole buffer unless `recompute` is set."""
        if not self.recompute:
            self._compute_advantage_and_return(buffer)

    def objective(self, batch: dict[str, Any]):
        """Compute advantages and returns for `batch` when `recompute` is set."""
        if self.recompute:
            self._compute_advantage_and_return(batch)

    def post_update(self):
        """Re-express the value head in the updated PopArt statistics.

        The head's weights and bias are rescaled so that its unnormalized
        output is unchanged, then the critic's normalizer is overwritten with
        the hook's updated statistics.
        """
        if self.value_rms is None:
            return

        old_value_rms: ExponentialMovingNormalizer = self.agent.critic.value_rms
        old_mean, old_std = old_value_rms.mean, old_value_rms.std
        new_mean, new_std = self.value_rms.mean, self.value_rms.std
        value_head = self.agent.critic.value_head
        with torch.no_grad():
            # The weight matrix is (out_features, in_features), so the per-channel
            # ratio must be a column to scale each output row. A flat ratio would
            # broadcast along the input axis and break for value_dim > 1.
            # NOTE(review): assumes std holds one entry per value channel - confirm.
            row_scale = (old_std / new_std).reshape(-1, 1)
            value_head.weight.data.mul_(row_scale)
            value_head.bias.data.mul_(old_std).add_(old_mean).sub_(new_mean).div_(new_std)
        old_value_rms.load_state_dict(self.value_rms.state_dict())

    @torch.no_grad()
    def _compute_advantage_and_return(self, data):
        """Write `data["advantage"]` and `data["return"]` from the stored values.

        Reads `value`, `next_value`, `reward` and `done` from `data`. When the
        critic carries a `value_rms`, the values are unnormalized first and the
        resulting returns are normalized in place with that same normalizer.
        """
        value = data["value"]
        next_value = data["next_value"]
        if (value_rms := self.agent.critic.value_rms) is not None:
            value = value_rms.unnormalize(value)
            next_value = value_rms.unnormalize(next_value)

        data["advantage"] = _generalized_advantage_estimation(
            reward=data["reward"],
            done=data["done"],
            value=value,
            next_value=next_value,
            gamma=self.gamma,
            lamda=self.lamda,
        )

        # The return target reuses the policy advantage unless a dedicated
        # lambda for the value function was configured.
        data["return"] = value + (
            data["advantage"]
            if self.lamda_value is None
            else _generalized_advantage_estimation(
                reward=data["reward"],
                done=data["done"],
                value=value,
                next_value=next_value,
                gamma=self.gamma,
                lamda=self.lamda_value,
            )
        )

        if value_rms is not None:
            # New statistics go to the hook's normalizer; the returns are still
            # normalized with the critic's current statistics.
            self.value_rms.update(data["return"])
            value_rms.normalize_(data["return"])
cusrl/hook/gradient.py ADDED
@@ -0,0 +1,48 @@
1
+ from torch import nn
2
+
3
+ from cusrl.template import Hook
4
+
5
+ __all__ = ["GradientClipping"]
6
+
7
+
8
class GradientClipping(Hook):
    """Clips gradient norms before the optimizer step, per group of parameters.

    Parameters are grouped by the prefix of their name so that modules with
    different gradient scales can be clipped with different limits.

    Args:
        max_grad_norm (float | None, optional):
            Max norm for parameters that match no explicit prefix. If None,
            those parameters are not clipped. Defaults to 1.0.
        **groups (dict[str, float | None]):
            Keyword arguments mapping parameter name prefixes to their max
            gradient norm; None disables clipping for that prefix.

    Raises:
        ValueError: If any configured max norm is negative.
    """

    def __init__(self, max_grad_norm: float | None = 1.0, **groups: float | None):
        # The empty prefix is the fallback group.
        groups[""] = max_grad_norm
        for prefix, limit in groups.items():
            if limit is None:
                continue
            if limit < 0:
                raise ValueError(f"'max_grad_norm' for prefix '{prefix}' must be positive.")
        # Longest prefixes first so the most specific group is matched first.
        ordered = sorted(groups.items(), key=lambda item: len(item[0]), reverse=True)
        self.groups = dict(ordered)

    def pre_optim(self, optimizer):
        """Bucket the optimizer's parameters by prefix and clip each bucket."""
        buckets = {prefix: [] for prefix in self.groups}
        for param_group in optimizer.param_groups:
            params = param_group["params"]
            # Param groups without names fall into the fallback group.
            names = param_group.get("param_names", [""] * len(params))
            for param, name in zip(params, names, strict=True):
                buckets[self._match_prefix(name)].append(param)

        for prefix, bucket in buckets.items():
            if not bucket:
                continue
            limit = self.groups[prefix]
            if limit is not None:
                nn.utils.clip_grad_norm_(bucket, limit)

    def _match_prefix(self, name):
        """Return the first configured prefix matching `name`, or "" if none does."""
        matches = (prefix for prefix in self.groups if name == prefix or name.startswith(f"{prefix}."))
        return next(matches, "")
cusrl/hook/initialization.py ADDED
@@ -0,0 +1,94 @@
1
+ import itertools
2
+ import math
3
+ from typing import Literal
4
+
5
+ from torch import nn
6
+
7
+ from cusrl.template import ActorCritic, Hook
8
+
9
+ __all__ = ["ModuleInitialization"]
10
+
11
+
12
class ModuleInitialization(Hook[ActorCritic]):
    """Re-initializes the weights of the agent's actor and critic.

    Linear, GRU/LSTM and multi-head attention weights are initialized
    orthogonally; Conv2d weights use Kaiming-normal initialization.

    Args:
        scale (float, optional):
            Gain of the orthogonal initialization. Defaults to sqrt(2).
        scale_dist (float, optional):
            Gain used for `actor.distribution.mean_head`, applied when it
            differs from `scale`. Defaults to sqrt(2) * 0.1.
        zero_bias (bool, optional):
            If True, biases of the initialized modules are set to zero.
            Defaults to True.
        conv_a (float, optional):
            `a` argument passed to `nn.init.kaiming_normal_`. Defaults to 0.0.
        conv_mode (Literal["fan_in", "fan_out"], optional):
            `mode` argument passed to `nn.init.kaiming_normal_`.
        conv_nonlinearity (Literal["relu", "leaky_relu"], optional):
            `nonlinearity` argument passed to `nn.init.kaiming_normal_`.
        init_actor (bool, optional):
            Whether to initialize the actor. Defaults to True.
        init_critic (bool, optional):
            Whether to initialize the critic. Defaults to True.
        distribution_std (float | None, optional):
            If not None, passed to `actor.set_distribution_std`.
    """

    def __init__(
        self,
        scale: float = math.sqrt(2),
        scale_dist: float = math.sqrt(2) * 0.1,
        zero_bias: bool = True,
        conv_a: float = 0.0,
        conv_mode: Literal["fan_in", "fan_out"] = "fan_in",
        conv_nonlinearity: Literal["relu", "leaky_relu"] = "leaky_relu",
        init_actor: bool = True,
        init_critic: bool = True,
        distribution_std: float | None = None,
    ):
        self.scale = scale
        self.scale_dist = scale_dist
        self.zero_bias = zero_bias
        self.conv_a = conv_a
        self.conv_mode = conv_mode
        self.conv_nonlinearity = conv_nonlinearity
        self.init_actor = init_actor
        self.init_critic = init_critic
        self.distribution_std = distribution_std

    def init(self):
        """Initialize the actor and/or critic according to the configuration."""
        if self.init_actor:
            for module in self.agent.actor.modules():
                self._init_module(module, self.scale, self.zero_bias)
            if self.scale_dist != self.scale:
                # The mean head was already covered by the loop above; redo it with its own gain.
                self._init_linear(self.agent.actor.distribution.mean_head, self.scale_dist, self.zero_bias)
            # NOTE(review): nesting reconstructed from a whitespace-stripped source -
            # confirm whether this branch should also run when init_actor is False.
            if self.distribution_std is not None:
                self.agent.actor.set_distribution_std(self.distribution_std)
        if self.init_critic:
            for module in self.agent.critic.modules():
                self._init_module(module, self.scale, self.zero_bias)

    def _init_module(self, module: nn.Module, scale: float, zero_bias: bool):
        """Dispatch to the initializer matching the module type; other modules are left untouched."""
        if isinstance(module, nn.Linear):
            self._init_linear(module, scale, zero_bias)
        elif isinstance(module, (nn.LSTM, nn.GRU)):
            self._init_gru_lstm(module, scale, zero_bias)
        elif isinstance(module, nn.MultiheadAttention):
            self._init_mha(module, scale, zero_bias)
        elif isinstance(module, nn.Conv2d):
            self._init_conv2d(module, zero_bias)

    def _init_linear(self, module: nn.Linear, scale: float, zero_bias: bool):
        """Orthogonally initialize a linear layer and optionally zero its bias."""
        nn.init.orthogonal_(module.weight, gain=scale)
        if zero_bias and module.bias is not None:
            nn.init.zeros_(module.bias)

    def _init_gru_lstm(self, module: nn.GRU | nn.LSTM, scale: float, zero_bias: bool):
        """Orthogonally initialize the recurrent weights of every layer and direction."""
        # Bidirectional modules keep a second parameter set with a "_reverse" suffix.
        suffixes = ("", "_reverse") if module.bidirectional else ("",)
        for layer in range(module.num_layers):
            for suffix in suffixes:
                tag = f"l{layer}{suffix}"
                nn.init.orthogonal_(getattr(module, f"weight_hh_{tag}"), gain=scale)
                nn.init.orthogonal_(getattr(module, f"weight_ih_{tag}"), gain=scale)
                # Modules built with bias=False have no bias attributes at all,
                # so consult the flag instead of looking the parameters up.
                if zero_bias and module.bias:
                    nn.init.zeros_(getattr(module, f"bias_hh_{tag}"))
                    nn.init.zeros_(getattr(module, f"bias_ih_{tag}"))

    def _init_mha(self, module: nn.MultiheadAttention, scale: float, zero_bias: bool):
        """Orthogonally initialize the input projections of a multi-head attention module."""
        if module.in_proj_weight is not None:
            nn.init.orthogonal_(module.in_proj_weight, gain=scale)
        else:
            # Separate q/k/v projection weights are used when the packed one is absent.
            nn.init.orthogonal_(module.q_proj_weight, gain=scale)
            nn.init.orthogonal_(module.k_proj_weight, gain=scale)
            nn.init.orthogonal_(module.v_proj_weight, gain=scale)

        if zero_bias:
            if module.in_proj_bias is not None:
                nn.init.zeros_(module.in_proj_bias)
            if module.bias_k is not None:
                nn.init.zeros_(module.bias_k)
            if module.bias_v is not None:
                nn.init.zeros_(module.bias_v)

    def _init_conv2d(self, module: nn.Conv2d, zero_bias: bool):
        """Kaiming-normal initialize a 2D convolution and optionally zero its bias."""
        nn.init.kaiming_normal_(
            module.weight,
            a=self.conv_a,
            mode=self.conv_mode,
            nonlinearity=self.conv_nonlinearity,
        )
        if zero_bias and module.bias is not None:
            nn.init.zeros_(module.bias)
cusrl/hook/lr_schedule.py ADDED
@@ -0,0 +1,178 @@
1
+ import math
2
+
3
+ import torch
4
+
5
+ from cusrl.template import ActorCritic, Hook
6
+ from cusrl.utils import distributed
7
+
8
+ __all__ = ["AdaptiveLRSchedule", "MiniBatchWiseLRSchedule", "ThresholdLRSchedule"]
9
+
10
+
11
class KLDivergenceBasedLRSchedule(Hook[ActorCritic]):
    """Base class for hooks that rescale learning rates from the measured KL divergence.

    Subclasses implement `_compute_scale`, which maps a KL divergence to a
    multiplicative learning-rate factor (or None to leave the rate unchanged).

    Args:
        desired_kl_divergence (float, optional):
            Target KL divergence; must be positive. Defaults to 0.01.
        scale_all_params (bool, optional):
            If True, every optimizer parameter group is rescaled; otherwise only
            groups containing a parameter whose name starts with "actor.".
            Defaults to False.

    Raises:
        ValueError: If `desired_kl_divergence` is not positive.
    """

    MUTABLE_ATTRS = ["desired_kl_divergence"]

    def __init__(
        self,
        desired_kl_divergence: float = 0.01,
        scale_all_params: bool = False,
    ):
        if desired_kl_divergence <= 0:
            raise ValueError("'desired_kl_divergence' must be positive.")

        self.desired_kl_divergence = desired_kl_divergence
        self.scale_all_params = scale_all_params
        # Product of all factors applied so far.
        self._lr_scale = 1.0

    def post_update(self):
        """Rescale learning rates from the mean KL divergence recorded in the agent's metrics."""
        # Clone so the in-place reduction does not touch the stored metric.
        mean_kl = self.agent.metrics["kl_divergence"].mean.clone()
        distributed.reduce_mean_(mean_kl)
        self._scale_lr_of_parameters(self._compute_scale(mean_kl.item()))
        self.agent.record(lr_scale=self._lr_scale)

    def _compute_scale(self, kl_divergence: float) -> float | None:
        """Return the learning-rate factor for `kl_divergence`; implemented by subclasses."""
        raise NotImplementedError

    def _scale_lr_of_parameters(self, scale: float | None):
        """Multiply the learning rate of the selected parameter groups by `scale`."""
        if scale is None or scale == 1.0:
            return

        self._lr_scale *= scale
        for group in self.agent.optimizer.param_groups:
            if not self.scale_all_params:
                touches_actor = any(name.startswith("actor.") for name in group["param_names"])
                if not touches_actor:
                    continue
            group["lr"] *= scale

    def state_dict(self):
        """Return the accumulated learning-rate scale."""
        return {"lr_scale": self._lr_scale}

    def load_state_dict(self, state_dict):
        """Restore the accumulated learning-rate scale."""
        self._lr_scale = state_dict["lr_scale"]
50
+
51
+
52
class ThresholdLRSchedule(KLDivergenceBasedLRSchedule):
    """Adjusts the learning rate when the KL divergence leaves a band around the target.

    Args:
        desired_kl_divergence (float, optional):
            Target KL divergence to maintain. Defaults to 0.01.
        threshold (float, optional):
            Ratio threshold (>1) for deciding when to adjust. Defaults to 1.2.
        scale_factor (float, optional):
            Multiplicative factor (>1) for scaling the LR. Defaults to 1.1.
        scale_all_params (bool, optional):
            If True, scales all optimizer parameter groups; otherwise only
            scales actor parameter groups. Defaults to False.

    Raises:
        ValueError: If `threshold` or `scale_factor` is not greater than 1.
    """

    MUTABLE_ATTRS = ["desired_kl_divergence"]

    def __init__(
        self,
        desired_kl_divergence: float = 0.01,
        threshold: float = 1.2,
        scale_factor: float = 1.1,
        scale_all_params: bool = False,
    ):
        super().__init__(desired_kl_divergence, scale_all_params)
        if threshold <= 1:
            raise ValueError("'threshold' must be greater than 1.")
        if scale_factor <= 1:
            raise ValueError("'scale_factor' must be greater than 1.")

        self.threshold = threshold
        self.scale_factor = scale_factor

    def _compute_scale(self, kl_divergence: float):
        """Shrink the LR above the band, grow it below the band, else return None."""
        upper_bound = self.desired_kl_divergence * self.threshold
        lower_bound = self.desired_kl_divergence / self.threshold
        if kl_divergence > upper_bound:
            return 1 / self.scale_factor
        if kl_divergence < lower_bound:
            return self.scale_factor
        return None
91
+
92
+
93
class AdaptiveLRSchedule(KLDivergenceBasedLRSchedule):
    """Adaptively adjusts the learning rate based on accumulated KL divergence error.

    The log-ratio between the measured and the desired KL divergence is
    accumulated over updates. Once its magnitude reaches `threshold`, the
    learning rate is scaled by ``exp(-clip(mean_log_error, -1, 1) * scale_factor)``
    and the accumulator is reset.

    Args:
        desired_kl_divergence (float, optional):
            Target KL divergence to maintain. Defaults to 0.01.
        threshold (float, optional):
            Positive threshold for accumulated log-error before scaling. Defaults to 1.0.
        scale_factor (float, optional):
            Positive coefficient controlling adjustment magnitude. Defaults to 0.2.
        scale_all_params (bool, optional):
            If True, scales all optimizer parameter groups; otherwise only scales actor
            parameter groups. Defaults to False.

    Raises:
        ValueError: If `threshold` or `scale_factor` is not positive.
    """

    MUTABLE_ATTRS = ["desired_kl_divergence"]

    def __init__(
        self,
        desired_kl_divergence: float = 0.01,
        threshold: float = 1.0,
        scale_factor: float = 0.2,
        scale_all_params: bool = False,
    ):
        super().__init__(desired_kl_divergence, scale_all_params)
        if threshold <= 0:
            raise ValueError("'threshold' must be positive.")
        if scale_factor <= 0:
            raise ValueError("'scale_factor' must be positive.")

        self.threshold = threshold
        self.scale_factor = scale_factor
        self.accumulated_log_error = 0.0
        self.count = 0

    def _compute_scale(self, kl_divergence: float):
        """Accumulate the log-error of `kl_divergence` and return a scale once it reaches the threshold.

        Returns:
            The multiplicative learning-rate factor, or None while the
            accumulated error is inside ``(-threshold, threshold)`` or when
            `kl_divergence` is NaN.
        """
        if math.isnan(kl_divergence):
            # A NaN measurement would poison the accumulator and, through the
            # returned scale, the learning rate itself; skip it.
            return None

        ratio = kl_divergence / self.desired_kl_divergence
        # math.log raises on non-positive input. A zero or negative KL estimate
        # is treated as "infinitely below target", which the clamp below turns
        # into the largest allowed learning-rate increase.
        log_error = math.log(ratio) if ratio > 0 else -math.inf

        self.accumulated_log_error += log_error
        self.count += 1
        if self.threshold > self.accumulated_log_error > -self.threshold:
            return None

        average_log_error = self.accumulated_log_error / self.count
        # Clamp the mean error to [-1, 1] so one adjustment stays within exp(+-scale_factor).
        scale = math.exp(-min(max(average_log_error, -1.0), 1.0) * self.scale_factor)
        self.accumulated_log_error = 0.0
        self.count = 0
        return scale
138
+
139
+
140
class MiniBatchWiseLRSchedule(ThresholdLRSchedule):
    """Applies a threshold-based LR schedule on a per-mini-batch KL divergence.

    All optimizer parameter groups are rescaled.
    Modified from [RSL-RL](https://github.com/leggedrobotics/rsl_rl).

    Args:
        desired_kl_divergence (float, optional):
            Target KL divergence per mini-batch. Defaults to 0.01.
        threshold (float, optional):
            Ratio threshold for deciding scaling per batch. Defaults to 2.0.
        scale_factor (float, optional):
            Multiplicative factor for scaling the LR. Defaults to 1.5.
    """

    def __init__(
        self,
        desired_kl_divergence: float = 0.01,
        threshold: float = 2.0,
        scale_factor: float = 1.5,
    ):
        super().__init__(
            desired_kl_divergence=desired_kl_divergence,
            threshold=threshold,
            scale_factor=scale_factor,
            scale_all_params=True,
        )

    def post_init(self):
        """Ask the `OnPolicyPreparation` hook to calculate the KL divergence."""
        preparation_hook = self.agent.hook["OnPolicyPreparation"]
        preparation_hook.calculate_kl_divergence = True

    def post_update(self):
        """Do nothing; the learning rate is adjusted per mini-batch in `objective`."""
        pass

    def objective(self, batch):
        """Rescale the learning rates from the mean KL divergence of `batch`."""
        with torch.inference_mode():
            batch_kl = batch["kl_divergence"].mean()
            distributed.reduce_mean_(batch_kl)
            self._scale_lr_of_parameters(self._compute_scale(batch_kl.item()))
            self.agent.record(lr_scale=self._lr_scale)
cusrl/hook/normalization.py ADDED
@@ -0,0 +1,194 @@
1
+ from collections.abc import Sequence
2
+ from typing import Any
3
+
4
+ import numpy as np
5
+ import torch
6
+
7
+ from cusrl.hook.symmetry import SymmetryDef
8
+ from cusrl.template import ActorCritic, Hook
9
+ from cusrl.utils import RunningMeanStd, mean_var_count
10
+ from cusrl.utils.export import ExportSpec
11
+ from cusrl.utils.typing import Slice
12
+
13
+ __all__ = ["ObservationNormalization"]
14
+
15
+
16
class ObservationNormalization(Hook[ActorCritic]):
    """Normalizes observations and states using a running mean and standard deviation.

    The hook keeps running statistics for observations and, when the agent has
    states, for states. It rewrites the `observation`, `state`,
    `next_observation` and `next_state` entries of each transition with their
    normalized values and stores the raw values under keys with an
    "original_" prefix.

    Statistics are not updated while the agent is in inference mode or while
    the `frozen` attribute is True.

    When the environment declares the observation to be a subset of the state,
    the observation statistics are copied from the corresponding entries of the
    state statistics instead of being estimated separately. Optional mirror
    functions from the environment spec are used to symmetrize the batch
    statistics. On export, forward pre-hooks are registered on the actor and
    critic modules so that their inputs are normalized.

    Args:
        max_count:
            The maximum count for the running statistics to prevent numerical
            overflow. Defaults to None.
        defer_synchronization:
            If True, synchronization of running statistics in a distributed setting
            is deferred to `pre_update` rather than done at every statistics
            update. Defaults to False.

    Raises:
        ValueError: If `max_count` is not None and not positive.
    """

    observation_rms: RunningMeanStd
    state_rms: RunningMeanStd | None = None
    _mirror_observation: SymmetryDef | None = None
    _mirror_state: SymmetryDef | None = None
    _observation_is_subset_of_state: Slice | torch.Tensor | None = None

    MODULES = ["observation_rms", "state_rms"]
    MUTABLE_ATTRS = ["frozen"]

    def __init__(self, max_count: int | None = None, defer_synchronization: bool = False):
        if max_count is not None and max_count <= 0:
            raise ValueError("'max_count' must be positive or None.")
        self.max_count = max_count
        self.frozen: bool = False
        self.defer_synchronization = defer_synchronization
        # `done` flags of the previous step; None until the first `post_step`.
        self._last_done: torch.Tensor | None = None

    def init(self):
        """Create the running statistics from the agent's environment spec."""
        env_spec = self.agent.environment_spec
        subset = env_spec.observation_is_subset_of_state
        if subset is not None:
            if not self.agent.has_state:
                raise ValueError("'observation_is_subset_of_state' is set but state is not defined.")
            # Index arrays and sequences are converted to a tensor for indexing.
            if isinstance(subset, (np.ndarray, Sequence)):
                subset = self.agent.to_tensor(np.asarray(subset))
            self._observation_is_subset_of_state = subset

        if self._observation_is_subset_of_state is None:
            self.observation_rms = self.__make_rms(
                self.agent.observation_dim, self.max_count, env_spec.observation_stat_groups
            )
        else:
            # Only a container: its statistics are copied from `state_rms`.
            self.observation_rms = self.__make_rms(self.agent.observation_dim)
        if self.agent.has_state:
            self.state_rms = self.__make_rms(self.agent.state_dim, self.max_count, env_spec.state_stat_groups)
        self._mirror_observation = env_spec.mirror_observation
        self._mirror_state = env_spec.mirror_state

    def pre_act(self, transition: dict):
        """Update the statistics if applicable and normalize the current observation and state."""
        observation = transition["observation"]
        state = transition.get("state")
        first_step = self._last_done is None
        if first_step or not self.agent.environment_spec.final_state_is_missing:
            # After the first step only the entries selected by the previous
            # `done` flags are used for the update.
            self.__update_rms(observation, state, self._last_done)

        transition["original_observation"] = observation
        transition["observation"] = self.observation_rms.normalize(observation)
        if state is not None:
            transition["original_state"] = state
            transition["state"] = self.state_rms.normalize(state)

    def post_step(self, transition: dict):
        """Update the statistics and normalize the next observation and state."""
        next_observation = transition["next_observation"]
        next_state = transition.get("next_state")
        self.__update_rms(next_observation, next_state)
        self._last_done = transition["done"].squeeze(-1)

        transition["original_next_observation"] = next_observation
        transition["next_observation"] = self.observation_rms.normalize(next_observation)
        if next_state is not None:
            transition["original_next_state"] = next_state
            transition["next_state"] = self.state_rms.normalize(next_state)

    def __make_rms(
        self,
        num_channels: int,
        max_count: int | None = None,
        stat_groups: tuple[tuple[int, int], ...] = (),
    ):
        """Build a `RunningMeanStd` on the agent's device and register its stat groups."""
        rms = RunningMeanStd(num_channels, max_count=max_count).to(self.agent.device)
        for stat_group in stat_groups:
            rms.register_stat_group(*stat_group)
        return rms

    def __copy_state_stats_to_observation_rms(self):
        """Overwrite the observation statistics with the matching entries of the state statistics."""
        self.observation_rms.mean.copy_(self.state_rms.mean[self._observation_is_subset_of_state])
        self.observation_rms.var.copy_(self.state_rms.var[self._observation_is_subset_of_state])
        self.observation_rms.std.copy_(self.state_rms.std[self._observation_is_subset_of_state])
        self.observation_rms.count = self.state_rms.count

    def __update_rms(
        self,
        observation: torch.Tensor,
        state: torch.Tensor | None,
        indices: torch.Tensor | None = None,
    ):
        """Update state and observation statistics from a batch, optionally restricted to `indices`."""
        # Do not update the statistics during inference or if frozen.
        if self.agent.inference_mode or self.frozen:
            return

        if state is not None:
            self.__update_rms_impl(state, self.state_rms, self._mirror_state, indices)
        if self._observation_is_subset_of_state is None:
            self.__update_rms_impl(observation, self.observation_rms, self._mirror_observation, indices)
        else:
            self.__copy_state_stats_to_observation_rms()

    def __update_rms_impl(
        self,
        observation: torch.Tensor,
        rms: RunningMeanStd,
        mirror: SymmetryDef | None = None,
        indices: torch.Tensor | None = None,
    ):
        """Feed the (optionally mirrored) batch statistics of `observation` into `rms`."""
        batch = observation if indices is None else observation[indices]
        mean, var, count = mean_var_count(batch)
        if mirror is not None:
            # Combine the batch statistics with those of the mirrored batch,
            # as if both had been observed equally often.
            mirrored_mean = mirror(mean)
            mirrored_var = abs(mirror(var))
            var = (var + mirrored_var) / 2 + (mean - mirrored_mean) ** 2 / 4
            mean = (mean + mirrored_mean) / 2
        rms.update_from_stats(mean, var, count, synchronize=not self.defer_synchronization)

    def pre_update(self, buffer):
        """Perform the deferred synchronization of the running statistics, if enabled."""
        if not self.defer_synchronization:
            return

        if self.state_rms is not None:
            self.state_rms.synchronize()
        if self._observation_is_subset_of_state is None:
            self.observation_rms.synchronize()
        else:
            self.__copy_state_stats_to_observation_rms()

    def export(self, export_data: dict[str, ExportSpec]):
        """Register input-normalizing forward pre-hooks on the exported actor and critic."""
        export_data["actor"].module.register_forward_pre_hook(
            self.__normalize_observation_forward_pre_hook, with_kwargs=True
        )
        if "critic" in export_data:
            # The critic is normalized with state statistics when the agent has a state.
            if self.agent.has_state:
                critic_pre_hook = self.__normalize_state_forward_pre_hook
            else:
                critic_pre_hook = self.__normalize_observation_forward_pre_hook
            export_data["critic"].module.register_forward_pre_hook(critic_pre_hook, with_kwargs=True)

    @torch.no_grad()
    def __normalize_observation_forward_pre_hook(self, module, args: tuple, kwargs: dict[str, Any]):
        """Normalize the `observation` keyword argument, or else the first positional argument."""
        if "observation" in kwargs:
            kwargs["observation"] = self.observation_rms.normalize(kwargs["observation"])
            return args, kwargs
        head, *tail = args
        return (self.observation_rms.normalize(head), *tail), kwargs

    @torch.no_grad()
    def __normalize_state_forward_pre_hook(self, module, args: tuple, kwargs: dict[str, Any]):
        """Normalize the `state` keyword argument, or else the first positional argument."""
        if "state" in kwargs:
            kwargs["state"] = self.state_rms.normalize(kwargs["state"])
            return args, kwargs
        head, *tail = args
        return (self.state_rms.normalize(head), *tail), kwargs