dreamer4-0.0.100-py3-none-any.whl → dreamer4-0.0.102-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dreamer4/dreamer4.py +19 -5
- {dreamer4-0.0.100.dist-info → dreamer4-0.0.102.dist-info}/METADATA +1 -1
- dreamer4-0.0.102.dist-info/RECORD +8 -0
- dreamer4-0.0.100.dist-info/RECORD +0 -8
- {dreamer4-0.0.100.dist-info → dreamer4-0.0.102.dist-info}/WHEEL +0 -0
- {dreamer4-0.0.100.dist-info → dreamer4-0.0.102.dist-info}/licenses/LICENSE +0 -0
dreamer4/dreamer4.py
CHANGED
```diff
@@ -1900,7 +1900,9 @@ class DynamicsWorldModel(Module):
         gae_lambda = 0.95,
         ppo_eps_clip = 0.2,
         pmpo_pos_to_neg_weight = 0.5, # pos and neg equal weight
+        pmpo_reverse_kl = True,
         pmpo_kl_div_loss_weight = .3,
+        normalize_advantages = None,
         value_clip = 0.4,
         policy_entropy_weight = .01,
         gae_use_accelerated = False
```
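Two new hyperparameters are introduced here: `pmpo_reverse_kl = True` toggles the direction of the prior KL term used with PMPO, and `normalize_advantages = None` leaves advantage normalization unset so a default can be derived later from whether PMPO is active (see the hunk at 2505 below).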
```diff
@@ -2108,6 +2110,7 @@ class DynamicsWorldModel(Module):
 
         self.pmpo_pos_to_neg_weight = pmpo_pos_to_neg_weight
         self.pmpo_kl_div_loss_weight = pmpo_kl_div_loss_weight
+        self.pmpo_reverse_kl = pmpo_reverse_kl
 
         # rewards related
 
```
```diff
@@ -2423,6 +2426,7 @@ class DynamicsWorldModel(Module):
         value_optim: Optimizer | None = None,
         only_learn_policy_value_heads = True, # in the paper, they do not finetune the entire dynamics model, they just learn the heads
         use_pmpo = True,
+        normalize_advantages = None,
         eps = 1e-6
     ):
 
```
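The same `normalize_advantages = None` override is also threaded through the policy/value learning entrypoint, so callers can force normalization on or off per call rather than relying on the constructor default.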
```diff
@@ -2505,16 +2509,19 @@ class DynamicsWorldModel(Module):
         else:
             advantage = returns - old_values
 
-        #
+        # if using pmpo, do not normalize advantages, but can be overridden
+
+        normalize_advantages = default(normalize_advantages, not use_pmpo)
+
+        if normalize_advantages:
+            advantage = F.layer_norm(advantage, advantage.shape, eps = eps)
+
         # https://arxiv.org/abs/2410.04166v1
 
         if use_pmpo:
             pos_advantage_mask = advantage >= 0.
             neg_advantage_mask = ~pos_advantage_mask
 
-        else:
-            advantage = F.layer_norm(advantage, advantage.shape, eps = eps)
-
         # replay for the action logits and values
         # but only do so if fine tuning the entire world model for RL
 
```
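The resolution `default(normalize_advantages, not use_pmpo)` means PMPO runs skip normalization unless explicitly overridden, which is a plausible choice given that the masks in this hunk only consume the sign of the advantage. Below is a minimal sketch of that gating, assuming `default` has the usual semantics in this codebase (return the value unless it is `None`); `maybe_normalize_advantages` is a hypothetical wrapper for illustration, not a function in dreamer4:

```python
import torch
import torch.nn.functional as F

def default(value, fallback):
    # assumed helper semantics: keep an explicit value, otherwise fall back
    return value if value is not None else fallback

def maybe_normalize_advantages(advantage, use_pmpo, normalize_advantages = None, eps = 1e-6):
    # if using pmpo, do not normalize advantages by default, but allow override
    normalize_advantages = default(normalize_advantages, not use_pmpo)

    if normalize_advantages:
        # layer_norm over the whole tensor standardizes to zero mean / unit variance
        advantage = F.layer_norm(advantage, advantage.shape, eps = eps)

    return advantage

advantage = torch.randn(4, 16)

# PMPO default: advantages pass through untouched
assert torch.equal(maybe_normalize_advantages(advantage, use_pmpo = True), advantage)

# PPO default (or explicit override): advantages are standardized
normalized = maybe_normalize_advantages(advantage, use_pmpo = False)
```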
```diff
@@ -2578,11 +2585,18 @@ class DynamicsWorldModel(Module):
         # take care of kl
 
         if self.pmpo_kl_div_loss_weight > 0.:
+
             new_unembedded_actions = self.action_embedder.unembed(policy_embed, pred_head_index = 0)
 
+            kl_div_inputs, kl_div_targets = new_unembedded_actions, old_action_unembeds
+
             # mentioned that the "reverse direction for the prior KL" was used
+            # make optional, as observed instability in toy task
+
+            if self.pmpo_reverse_kl:
+                kl_div_inputs, kl_div_targets = kl_div_targets, kl_div_inputs
 
-            discrete_kl_div, continuous_kl_div = self.action_embedder.kl_div(new_unembedded_actions, old_action_unembeds)
+            discrete_kl_div, continuous_kl_div = self.action_embedder.kl_div(kl_div_inputs, kl_div_targets)
 
         # accumulate discrete and continuous kl div
 
```
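Swapping the two arguments flips which distribution plays the role of the reference in the KL penalty. A minimal sketch of the swap for the discrete case, assuming `kl_div(inputs, targets)` computes KL(inputs ‖ targets); `categorical_kl` here is an illustrative stand-in, not the `action_embedder.kl_div` API:

```python
import torch

def categorical_kl(logits_p, logits_q):
    # KL(p || q) between two categorical distributions given as logits
    log_p = logits_p.log_softmax(dim = -1)
    log_q = logits_q.log_softmax(dim = -1)
    return (log_p.exp() * (log_p - log_q)).sum(dim = -1)

new_logits = torch.randn(2, 6)  # current policy's unembedded action logits
old_logits = torch.randn(2, 6)  # old (prior) policy's logits

pmpo_reverse_kl = True

kl_div_inputs, kl_div_targets = new_logits, old_logits

if pmpo_reverse_kl:
    # reverse the direction of the prior KL, as in the hunk above
    kl_div_inputs, kl_div_targets = kl_div_targets, kl_div_inputs

kl_penalty = categorical_kl(kl_div_inputs, kl_div_targets).mean()
```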
dreamer4-0.0.102.dist-info/RECORD
ADDED
```diff
@@ -0,0 +1,8 @@
+dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
+dreamer4/dreamer4.py,sha256=3qeVN3qdvx7iPxA0OBXw_yy5Re6rX6FIKITH9bp6RBs,119202
+dreamer4/mocks.py,sha256=TfqOB_Gq6N_GggBYwa6ZAJQx38ntlYbXZe23Ne4jshw,2502
+dreamer4/trainers.py,sha256=JsnJwQJcbI_75KBTNddG6b7QVkO6LD1N_HQiVe-VnCM,15087
+dreamer4-0.0.102.dist-info/METADATA,sha256=xxVL1sFimb0azSD5sDOEzugY7rBT6oDek4YdiIS8m18,3066
+dreamer4-0.0.102.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+dreamer4-0.0.102.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+dreamer4-0.0.102.dist-info/RECORD,,
```
dreamer4-0.0.100.dist-info/RECORD
REMOVED
```diff
@@ -1,8 +0,0 @@
-dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
-dreamer4/dreamer4.py,sha256=9r2qDg6SpCe6Y2MWzI44o369t1a4b_LhfQSI_FK5WHQ,118665
-dreamer4/mocks.py,sha256=TfqOB_Gq6N_GggBYwa6ZAJQx38ntlYbXZe23Ne4jshw,2502
-dreamer4/trainers.py,sha256=JsnJwQJcbI_75KBTNddG6b7QVkO6LD1N_HQiVe-VnCM,15087
-dreamer4-0.0.100.dist-info/METADATA,sha256=-hOF9eyycsndS5u8-i6o9IikCDracHn0mIIv_g5dLRo,3066
-dreamer4-0.0.100.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-dreamer4-0.0.100.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-dreamer4-0.0.100.dist-info/RECORD,,
```
{dreamer4-0.0.100.dist-info → dreamer4-0.0.102.dist-info}/WHEEL
File without changes

{dreamer4-0.0.100.dist-info → dreamer4-0.0.102.dist-info}/licenses/LICENSE
File without changes