dreamer4 0.0.100__py3-none-any.whl → 0.0.102__py3-none-any.whl

This diff covers publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
dreamer4/dreamer4.py CHANGED
@@ -1900,7 +1900,9 @@ class DynamicsWorldModel(Module):
         gae_lambda = 0.95,
         ppo_eps_clip = 0.2,
         pmpo_pos_to_neg_weight = 0.5, # pos and neg equal weight
+        pmpo_reverse_kl = True,
         pmpo_kl_div_loss_weight = .3,
+        normalize_advantages = None,
         value_clip = 0.4,
         policy_entropy_weight = .01,
         gae_use_accelerated = False
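
The two new hyperparameters above thread through the hunks below: pmpo_reverse_kl toggles the direction of the KL penalty, and normalize_advantages = None is a tri-state flag resolved at update time. A minimal sketch of that None-sentinel resolution, assuming a lucidrains-style default helper (the helper body is an assumption; it is not shown in this diff):

# sketch of the None-sentinel pattern behind normalize_advantages;
# the default helper below is assumed, in the style this repo uses

def default(value, fallback):
    # return the explicit value when given, otherwise the fallback
    return value if value is not None else fallback

use_pmpo = True
normalize_advantages = None  # None means "decide based on use_pmpo"

# resolves to False here, since PMPO only consumes the advantage sign
normalize_advantages = default(normalize_advantages, not use_pmpo)
print(normalize_advantages)  # False

Leaving the default at None keeps the old behavior (PMPO skips normalization, PPO-style updates normalize) while still letting the caller force either choice.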
@@ -2108,6 +2110,7 @@ class DynamicsWorldModel(Module):
 
         self.pmpo_pos_to_neg_weight = pmpo_pos_to_neg_weight
         self.pmpo_kl_div_loss_weight = pmpo_kl_div_loss_weight
+        self.pmpo_reverse_kl = pmpo_reverse_kl
 
         # rewards related
 
@@ -2423,6 +2426,7 @@ class DynamicsWorldModel(Module):
         value_optim: Optimizer | None = None,
         only_learn_policy_value_heads = True, # in the paper, they do not finetune the entire dynamics model, they just learn the heads
         use_pmpo = True,
+        normalize_advantages = None,
         eps = 1e-6
     ):
 
@@ -2505,16 +2509,19 @@ class DynamicsWorldModel(Module):
         else:
             advantage = returns - old_values
 
-        # apparently they just use the sign of the advantage
+        # if using pmpo, do not normalize advantages, but can be overridden
+
+        normalize_advantages = default(normalize_advantages, not use_pmpo)
+
+        if normalize_advantages:
+            advantage = F.layer_norm(advantage, advantage.shape, eps = eps)
+
         # https://arxiv.org/abs/2410.04166v1
 
         if use_pmpo:
             pos_advantage_mask = advantage >= 0.
             neg_advantage_mask = ~pos_advantage_mask
 
-        else:
-            advantage = F.layer_norm(advantage, advantage.shape, eps = eps)
-
         # replay for the action logits and values
         # but only do so if fine tuning the entire world model for RL
 
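For orientation on the relocated normalization: F.layer_norm over the full tensor shape standardizes the advantages to zero mean and unit variance, which the PMPO path can skip because it only consumes each advantage's sign. A small self-contained sketch, with a made-up advantage shape:

import torch
import torch.nn.functional as F

advantage = torch.randn(4, 16)  # hypothetical (batch, time) advantages

# layer_norm over the entire shape gives zero mean, unit variance overall
normalized = F.layer_norm(advantage, advantage.shape, eps = 1e-6)

# PMPO only uses the sign, which normalization does not change,
# hence the new default of skipping it when use_pmpo is set
pos_advantage_mask = advantage >= 0.
neg_advantage_mask = ~pos_advantage_mask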
@@ -2578,11 +2585,18 @@ class DynamicsWorldModel(Module):
         # take care of kl
 
         if self.pmpo_kl_div_loss_weight > 0.:
+
             new_unembedded_actions = self.action_embedder.unembed(policy_embed, pred_head_index = 0)
 
+            kl_div_inputs, kl_div_targets = new_unembedded_actions, old_action_unembeds
+
             # mentioned that the "reverse direction for the prior KL" was used
+            # make optional, as observed instability in toy task
+
+            if self.pmpo_reverse_kl:
+                kl_div_inputs, kl_div_targets = kl_div_targets, kl_div_inputs
 
-            discrete_kl_div, continuous_kl_div = self.action_embedder.kl_div(old_action_unembeds, new_unembedded_actions)
+            discrete_kl_div, continuous_kl_div = self.action_embedder.kl_div(kl_div_inputs, kl_div_targets)
 
             # accumulate discrete and continuous kl div
 
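The swap above flips which distribution is the input and which is the target of action_embedder.kl_div, i.e. the direction of the prior KL penalty. KL divergence is asymmetric, so the two directions constrain drift from the old policy differently, which matches the noted instability in one direction. A sketch of both directions for a hypothetical discrete action head (the argument-order convention of action_embedder.kl_div is not visible in this diff, so the mapping here is illustrative only):

import torch
from torch.distributions import Categorical, kl_divergence

# hypothetical logits for the old (prior) and new (current) policies
old_logits = torch.randn(2, 6)
new_logits = torch.randn(2, 6)

prior = Categorical(logits = old_logits)
current = Categorical(logits = new_logits)

# the two possible directions of the penalty; they are not equal in general
kl_new_to_old = kl_divergence(current, prior)  # KL(current || prior)
kl_old_to_new = kl_divergence(prior, current)  # KL(prior || current)

pmpo_reverse_kl = True  # mirrors the new hyperparameter's default

# the diff selects a direction by swapping the argument pair, as sketched here
kl_penalty = (kl_new_to_old if pmpo_reverse_kl else kl_old_to_new).mean()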
dreamer4-0.0.102.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dreamer4
-Version: 0.0.100
+Version: 0.0.102
 Summary: Dreamer 4
 Project-URL: Homepage, https://pypi.org/project/dreamer4/
 Project-URL: Repository, https://github.com/lucidrains/dreamer4
dreamer4-0.0.102.dist-info/RECORD ADDED
@@ -0,0 +1,8 @@
+dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
+dreamer4/dreamer4.py,sha256=3qeVN3qdvx7iPxA0OBXw_yy5Re6rX6FIKITH9bp6RBs,119202
+dreamer4/mocks.py,sha256=TfqOB_Gq6N_GggBYwa6ZAJQx38ntlYbXZe23Ne4jshw,2502
+dreamer4/trainers.py,sha256=JsnJwQJcbI_75KBTNddG6b7QVkO6LD1N_HQiVe-VnCM,15087
+dreamer4-0.0.102.dist-info/METADATA,sha256=xxVL1sFimb0azSD5sDOEzugY7rBT6oDek4YdiIS8m18,3066
+dreamer4-0.0.102.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+dreamer4-0.0.102.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+dreamer4-0.0.102.dist-info/RECORD,,
dreamer4-0.0.100.dist-info/RECORD DELETED
@@ -1,8 +0,0 @@
-dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
-dreamer4/dreamer4.py,sha256=9r2qDg6SpCe6Y2MWzI44o369t1a4b_LhfQSI_FK5WHQ,118665
-dreamer4/mocks.py,sha256=TfqOB_Gq6N_GggBYwa6ZAJQx38ntlYbXZe23Ne4jshw,2502
-dreamer4/trainers.py,sha256=JsnJwQJcbI_75KBTNddG6b7QVkO6LD1N_HQiVe-VnCM,15087
-dreamer4-0.0.100.dist-info/METADATA,sha256=-hOF9eyycsndS5u8-i6o9IikCDracHn0mIIv_g5dLRo,3066
-dreamer4-0.0.100.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-dreamer4-0.0.100.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-dreamer4-0.0.100.dist-info/RECORD,,