dreamer4 0.0.101__tar.gz → 0.0.102__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dreamer4 might be problematic. Click here for more details.
- {dreamer4-0.0.101 → dreamer4-0.0.102}/PKG-INFO +1 -1
- {dreamer4-0.0.101 → dreamer4-0.0.102}/dreamer4/dreamer4.py +9 -4
- {dreamer4-0.0.101 → dreamer4-0.0.102}/pyproject.toml +1 -1
- {dreamer4-0.0.101 → dreamer4-0.0.102}/.github/workflows/python-publish.yml +0 -0
- {dreamer4-0.0.101 → dreamer4-0.0.102}/.github/workflows/test.yml +0 -0
- {dreamer4-0.0.101 → dreamer4-0.0.102}/.gitignore +0 -0
- {dreamer4-0.0.101 → dreamer4-0.0.102}/LICENSE +0 -0
- {dreamer4-0.0.101 → dreamer4-0.0.102}/README.md +0 -0
- {dreamer4-0.0.101 → dreamer4-0.0.102}/dreamer4/__init__.py +0 -0
- {dreamer4-0.0.101 → dreamer4-0.0.102}/dreamer4/mocks.py +0 -0
- {dreamer4-0.0.101 → dreamer4-0.0.102}/dreamer4/trainers.py +0 -0
- {dreamer4-0.0.101 → dreamer4-0.0.102}/dreamer4-fig2.png +0 -0
- {dreamer4-0.0.101 → dreamer4-0.0.102}/tests/test_dreamer.py +0 -0
|
@@ -1902,6 +1902,7 @@ class DynamicsWorldModel(Module):
|
|
|
1902
1902
|
pmpo_pos_to_neg_weight = 0.5, # pos and neg equal weight
|
|
1903
1903
|
pmpo_reverse_kl = True,
|
|
1904
1904
|
pmpo_kl_div_loss_weight = .3,
|
|
1905
|
+
normalize_advantages = None,
|
|
1905
1906
|
value_clip = 0.4,
|
|
1906
1907
|
policy_entropy_weight = .01,
|
|
1907
1908
|
gae_use_accelerated = False
|
|
@@ -2425,6 +2426,7 @@ class DynamicsWorldModel(Module):
|
|
|
2425
2426
|
value_optim: Optimizer | None = None,
|
|
2426
2427
|
only_learn_policy_value_heads = True, # in the paper, they do not finetune the entire dynamics model, they just learn the heads
|
|
2427
2428
|
use_pmpo = True,
|
|
2429
|
+
normalize_advantages = None,
|
|
2428
2430
|
eps = 1e-6
|
|
2429
2431
|
):
|
|
2430
2432
|
|
|
@@ -2507,16 +2509,19 @@ class DynamicsWorldModel(Module):
|
|
|
2507
2509
|
else:
|
|
2508
2510
|
advantage = returns - old_values
|
|
2509
2511
|
|
|
2510
|
-
#
|
|
2512
|
+
# if using pmpo, do not normalize advantages, but can be overridden
|
|
2513
|
+
|
|
2514
|
+
normalize_advantages = default(normalize_advantages, not use_pmpo)
|
|
2515
|
+
|
|
2516
|
+
if normalize_advantages:
|
|
2517
|
+
advantage = F.layer_norm(advantage, advantage.shape, eps = eps)
|
|
2518
|
+
|
|
2511
2519
|
# https://arxiv.org/abs/2410.04166v1
|
|
2512
2520
|
|
|
2513
2521
|
if use_pmpo:
|
|
2514
2522
|
pos_advantage_mask = advantage >= 0.
|
|
2515
2523
|
neg_advantage_mask = ~pos_advantage_mask
|
|
2516
2524
|
|
|
2517
|
-
else:
|
|
2518
|
-
advantage = F.layer_norm(advantage, advantage.shape, eps = eps)
|
|
2519
|
-
|
|
2520
2525
|
# replay for the action logits and values
|
|
2521
2526
|
# but only do so if fine tuning the entire world model for RL
|
|
2522
2527
|
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|