dreamer4 0.0.80__tar.gz → 0.0.81__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dreamer4
-Version: 0.0.80
+Version: 0.0.81
 Summary: Dreamer 4
 Project-URL: Homepage, https://pypi.org/project/dreamer4/
 Project-URL: Repository, https://github.com/lucidrains/dreamer4
@@ -1817,6 +1817,8 @@ class DynamicsWorldModel(Module):
         continuous_action_loss_weight: float | list[float] = 1.,
         num_latent_genes = 0, # for carrying out evolution within the dreams https://web3.arxiv.org/abs/2503.19037
         num_residual_streams = 1,
+        keep_reward_ema_stats = False,
+        reward_ema_decay = 0.998,
         gae_discount_factor = 0.997,
         gae_lambda = 0.95,
         ppo_eps_clip = 0.2,
@@ -2022,6 +2024,14 @@ class DynamicsWorldModel(Module):
         self.value_clip = value_clip
         self.policy_entropy_weight = value_clip
 
+        # rewards related
+
+        self.keep_reward_ema_stats = keep_reward_ema_stats
+        self.reward_ema_decay = reward_ema_decay
+
+        self.register_buffer('ema_returns_mean', tensor(0.))
+        self.register_buffer('ema_returns_var', tensor(1.))
+
         # loss related
 
         self.flow_loss_normalizer = LossNormalizer(1)
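
The two EMA statistics are registered as buffers rather than kept as plain Python floats. A minimal standalone sketch of why that matters, using a hypothetical `ReturnEMA` module that is not part of dreamer4: buffers are included in `state_dict()` and follow the module across devices, so the running return statistics survive checkpointing and resume with training.

import torch
from torch import nn, tensor

class ReturnEMA(nn.Module):
    def __init__(self, decay = 0.998):
        super().__init__()
        self.decay = decay
        # buffers, like ema_returns_mean / ema_returns_var above, are saved in
        # state_dict() and moved by .to(device), unlike plain float attributes
        self.register_buffer('ema_mean', tensor(0.))
        self.register_buffer('ema_var', tensor(1.))

    @torch.no_grad()
    def update(self, returns):
        # lerp_ with weight (1 - decay) nudges the running stats toward the batch stats
        self.ema_mean.lerp_(returns.mean(), 1. - self.decay)
        self.ema_var.lerp_(returns.var(), 1. - self.decay)

ema = ReturnEMA()
ema.update(torch.randn(64))
print(ema.state_dict())  # contains 'ema_mean' and 'ema_var'
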
@@ -2267,11 +2277,32 @@ class DynamicsWorldModel(Module):
 
         world_model_forward_context = torch.no_grad if only_learn_policy_value_heads else nullcontext
 
+        # maybe keep track returns statistics and normalize returns and values before calculating advantage, as done in dreamer v3
+
+        if self.keep_reward_ema_stats:
+            ema_returns_mean, ema_returns_var = self.ema_returns_mean, self.ema_returns_var
+
+            decay = 1. - self.reward_ema_decay
+
+            # todo - handle distributed
+
+            returns_mean, returns_var = returns.mean(), returns.var()
+
+            ema_returns_mean.lerp_(returns_mean, decay)
+            ema_returns_var.lerp_(returns_var, decay)
+
+            ema_returns_std = ema_returns_var.clamp(min = 1e-5).sqrt()
+
+            normed_returns = (returns - ema_returns_mean) / ema_returns_std
+            normed_old_values = (old_values - ema_returns_mean) / ema_returns_std
+
+            advantage = normed_returns - normed_old_values
+        else:
+            advantage = returns - old_values
+
         # apparently they just use the sign of the advantage
         # https://arxiv.org/abs/2410.04166v1
 
-        advantage = returns - old_values
-
         if use_signed_advantage:
             advantage = advantage.sign()
         else:
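
The new branch follows Dreamer v3's idea of normalizing returns by running statistics before forming the advantage. Below is a self-contained sketch of that computation on plain tensors, using the same decay (0.998) and variance clamp (1e-5) as the diff; the function name and the way the statistics are passed in are our own illustration, not the library's API.

import torch

def ema_normalized_advantage(returns, old_values, ema_mean, ema_var, decay = 0.998):
    # update the running return statistics in place, mirroring the lerp_ calls above
    ema_mean.lerp_(returns.mean(), 1. - decay)
    ema_var.lerp_(returns.var(), 1. - decay)

    ema_std = ema_var.clamp(min = 1e-5).sqrt()

    # shift and scale both the returns and the value baseline by the same statistics
    normed_returns = (returns - ema_mean) / ema_std
    normed_old_values = (old_values - ema_mean) / ema_std

    return normed_returns - normed_old_values

ema_mean, ema_var = torch.tensor(0.), torch.tensor(1.)
returns, old_values = torch.randn(128), torch.randn(128)
advantage = ema_normalized_advantage(returns, old_values, ema_mean, ema_var)

Because both terms are shifted by the same mean and divided by the same positive standard deviation, the shift cancels and the unnormalized advantage is only rescaled by 1 / ema_std; in particular, when use_signed_advantage reduces the advantage to its sign, the normalization leaves that path unchanged.
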
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.80"
+version = "0.0.81"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }