dreamer4-0.0.90-py3-none-any.whl → dreamer4-0.0.92-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dreamer4/dreamer4.py +105 -22
- {dreamer4-0.0.90.dist-info → dreamer4-0.0.92.dist-info}/METADATA +1 -1
- dreamer4-0.0.92.dist-info/RECORD +8 -0
- dreamer4-0.0.90.dist-info/RECORD +0 -8
- {dreamer4-0.0.90.dist-info → dreamer4-0.0.92.dist-info}/WHEEL +0 -0
- {dreamer4-0.0.90.dist-info → dreamer4-0.0.92.dist-info}/licenses/LICENSE +0 -0
dreamer4/dreamer4.py
CHANGED
@@ -11,7 +11,7 @@ from dataclasses import dataclass, asdict
 import torch
 import torch.nn.functional as F
 from torch.nested import nested_tensor
-from torch.distributions import Normal
+from torch.distributions import Normal, kl
 from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity
 from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, full, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
 from torch.utils._pytree import tree_flatten, tree_unflatten
@@ -186,9 +186,26 @@ def lens_to_mask(t, max_len = None):

     return einx.less('j, i -> i j', seq, t)

+def masked_mean(t, mask = None):
+    if not exists(mask):
+        return t.mean()
+
+    if not mask.any():
+        return t[mask].sum()
+
+    return t[mask].mean()
+
 def log(t, eps = 1e-20):
     return t.clamp(min = eps).log()

+def mean_log_var_to_distr(
+    mean_log_var: Tensor
+) -> Normal:
+
+    mean, log_var = mean_log_var.unbind(dim = -1)
+    std = (0.5 * log_var).exp()
+    return Normal(mean, std)
+
 def safe_cat(tensors, dim):
     tensors = [*filter(exists, tensors)]

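The two new helpers are small but load-bearing for the rest of this diff. Below is a minimal standalone sketch of their behavior, assuming only torch (the bodies are copied from the hunk above, with the library's exists() check inlined): masked_mean falls back to a plain mean without a mask and returns a differentiable zero when the mask selects nothing, and mean_log_var_to_distr unpacks a (..., 2) tensor into a Normal.

import torch
from torch.distributions import Normal

def masked_mean(t, mask = None):
    if mask is None:                 # exists(mask) in the library
        return t.mean()

    if not mask.any():
        return t[mask].sum()         # empty selection -> zero, keeps the autograd graph

    return t[mask].mean()

t = torch.randn(4, 3, requires_grad = True)
mask = torch.tensor([True, True, False, False])[:, None].expand(4, 3)

print(masked_mean(t, mask))                     # mean over the first two rows only
print(masked_mean(t, torch.zeros_like(mask)))   # tensor(0.) with a grad_fn attached

# mean_log_var_to_distr: last dim packs (mean, log_var) -> Normal(mean, exp(0.5 * log_var))
mean_log_var = torch.randn(4, 3, 2)
mean, log_var = mean_log_var.unbind(dim = -1)
distr = Normal(mean, (0.5 * log_var).exp())
print(distr.log_prob(torch.zeros(4, 3)).shape)  # torch.Size([4, 3])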
@@ -815,10 +832,7 @@ class ActionEmbedder(Module):
         continuous_entropies = None

         if exists(continuous_targets):
-
-            std = (0.5 * log_var).exp()
-
-            distr = Normal(mean, std)
+            distr = mean_log_var_to_distr(continuous_action_mean_log_var)
             continuous_log_probs = distr.log_prob(continuous_targets)

         if return_entropies:
@@ -833,6 +847,54 @@ class ActionEmbedder(Module):

         return log_probs, entropies

+    def kl_div(
+        self,
+        src: tuple[Tensor | None, Tensor | None],
+        tgt: tuple[Tensor | None, Tensor | None]
+    ) -> tuple[Tensor | None, Tensor | None]:
+
+        src_discrete, src_continuous = src
+        tgt_discrete, tgt_continuous = tgt
+
+        discrete_kl_div = None
+
+        # split discrete if it is not already (multiple discrete actions)
+
+        if exists(src_discrete):
+
+            discrete_split = self.num_discrete_actions.tolist()
+
+            if is_tensor(src_discrete):
+                src_discrete = src_discrete.split(discrete_split, dim = -1)
+
+            if is_tensor(tgt_discrete):
+                tgt_discrete = tgt_discrete.split(discrete_split, dim = -1)
+
+            discrete_kl_divs = []
+
+            for src_logit, tgt_logit in zip(src_discrete, tgt_discrete):
+
+                src_log_probs = src_logit.log_softmax(dim = -1)
+                tgt_prob = tgt_logit.softmax(dim = -1)
+
+                one_discrete_kl_div = F.kl_div(src_log_probs, tgt_prob, reduction = 'none')
+
+                discrete_kl_divs.append(one_discrete_kl_div.sum(dim = -1))
+
+            discrete_kl_div = stack(discrete_kl_divs, dim = -1)
+
+        # calculate kl divergence for continuous
+
+        continuous_kl_div = None
+
+        if exists(src_continuous):
+            src_normal = mean_log_var_to_distr(src_continuous)
+            tgt_normal = mean_log_var_to_distr(tgt_continuous)
+
+            continuous_kl_div = kl.kl_divergence(src_normal, tgt_normal)
+
+        return discrete_kl_div, continuous_kl_div
+
     def forward(
         self,
         *,
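A hedged sketch of what the new kl_div method computes, reproduced outside the class with made-up shapes: two hypothetical discrete action heads (3 and 5 choices) packed along the last dim, and one continuous action packed as (mean, log_var). The head sizes, shapes, and helper names here are illustrative, not part of the library API.

import torch
import torch.nn.functional as F
from torch.distributions import Normal, kl

num_discrete_actions = [3, 5]          # hypothetical: two discrete heads
src_logits = torch.randn(2, 4, 8)      # (batch, time, 3 + 5 logits)
tgt_logits = torch.randn(2, 4, 8)

# discrete: per-head KL via F.kl_div, which takes log-probs as input and
# probs as target, i.e. each term is KL(tgt || src) summed over choices
discrete_kl_divs = []
for src_logit, tgt_logit in zip(src_logits.split(num_discrete_actions, dim = -1),
                                tgt_logits.split(num_discrete_actions, dim = -1)):
    one_kl = F.kl_div(src_logit.log_softmax(dim = -1), tgt_logit.softmax(dim = -1), reduction = 'none')
    discrete_kl_divs.append(one_kl.sum(dim = -1))

discrete_kl_div = torch.stack(discrete_kl_divs, dim = -1)    # (2, 4, 2): one value per head

# continuous: closed-form Normal-to-Normal KL via torch.distributions.kl
def mean_log_var_to_distr(mean_log_var):
    mean, log_var = mean_log_var.unbind(dim = -1)
    return Normal(mean, (0.5 * log_var).exp())

src_mlv = torch.randn(2, 4, 1, 2)      # (..., num continuous actions, 2)
tgt_mlv = torch.randn(2, 4, 1, 2)

continuous_kl_div = kl.kl_divergence(mean_log_var_to_distr(src_mlv), mean_log_var_to_distr(tgt_mlv))

print(discrete_kl_div.shape, continuous_kl_div.shape)        # (2, 4, 2) and (2, 4, 1)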
@@ -1824,6 +1886,7 @@ class DynamicsWorldModel(Module):
         gae_discount_factor = 0.997,
         gae_lambda = 0.95,
         ppo_eps_clip = 0.2,
+        pmpo_pos_to_neg_weight = 0.5, # pos and neg equal weight
         value_clip = 0.4,
         policy_entropy_weight = .01,
         gae_use_accelerated = False
@@ -2027,6 +2090,10 @@ class DynamicsWorldModel(Module):
         self.value_clip = value_clip
         self.policy_entropy_weight = value_clip

+        # pmpo related
+
+        self.pmpo_pos_to_neg_weight = pmpo_pos_to_neg_weight
+
         # rewards related

         self.keep_reward_ema_stats = keep_reward_ema_stats
@@ -2334,7 +2401,7 @@ class DynamicsWorldModel(Module):
         policy_optim: Optimizer | None = None,
         value_optim: Optimizer | None = None,
         only_learn_policy_value_heads = True, # in the paper, they do not finetune the entire dynamics model, they just learn the heads
-
+        use_pmpo = True,
         eps = 1e-6
     ):

@@ -2374,6 +2441,8 @@ class DynamicsWorldModel(Module):
         max_time = latents.shape[1]
         is_var_len = exists(experience.lens)

+        mask = None
+
         if is_var_len:
             learnable_lens = experience.lens - experience.is_truncated.long() # if is truncated, remove the last one, as it is bootstrapped value
             mask = lens_to_mask(learnable_lens, max_time)
@@ -2417,8 +2486,9 @@ class DynamicsWorldModel(Module):
         # apparently they just use the sign of the advantage
         # https://arxiv.org/abs/2410.04166v1

-        if
-
+        if use_pmpo:
+            pos_advantage_mask = advantage >= 0.
+            neg_advantage_mask = ~pos_advantage_mask
         else:
             advantage = F.layer_norm(advantage, advantage.shape, eps = eps)

@@ -2464,35 +2534,48 @@ class DynamicsWorldModel(Module):
         log_probs = safe_cat(log_probs, dim = -1)
         entropies = safe_cat(entropies, dim = -1)

-
+        advantage = rearrange(advantage, '... -> ... 1') # broadcast across all actions

-
+        if use_pmpo:
+            # pmpo - weighting the positive and negative advantages equally - ignoring magnitude of advantage and taking the sign
+            # seems to be weighted across batch and time, iiuc
+            # eq (10) in https://arxiv.org/html/2410.04166v1

-
+            if exists(mask):
+                pos_advantage_mask &= mask
+                neg_advantage_mask &= mask

-
+            α = self.pmpo_pos_to_neg_weight

-
-
+            pos = masked_mean(log_probs, pos_advantage_mask)
+            neg = -masked_mean(log_probs, neg_advantage_mask)

-
+            policy_loss = -(α * pos + (1. - α) * neg)
+
+        else:
+            # ppo clipped surrogate loss
+
+            ratio = (log_probs - old_log_probs).exp()
+            clipped_ratio = ratio.clamp(1. - self.ppo_eps_clip, 1. + self.ppo_eps_clip)
+
+            policy_loss = -torch.min(ratio * advantage, clipped_ratio * advantage)
+            policy_loss = reduce(policy_loss, 'b t na -> b t', 'sum')
+
+            policy_loss = masked_mean(policy_loss, mask)

         # handle entropy loss for naive exploration bonus

         entropy_loss = - reduce(entropies, 'b t na -> b t', 'sum')

+        entropy_loss = masked_mean(entropy_loss, mask)
+
+        # total policy loss
+
         total_policy_loss = (
             policy_loss +
             entropy_loss * self.policy_entropy_weight
         )

-        # maybe handle variable lengths
-
-        if is_var_len:
-            total_policy_loss = total_policy_loss[mask].mean()
-        else:
-            total_policy_loss = total_policy_loss.mean()
-
         # maybe take policy optimizer step

         if exists(policy_optim):
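To make the new branch above concrete: PMPO discards the magnitude of the advantage and keeps only its sign, averaging log-probs over positive and negative samples separately and mixing the two means with the α weight (eq. (10) of https://arxiv.org/abs/2410.04166), while the else branch is the familiar PPO clipped surrogate. A self-contained sketch under assumed shapes (batch, time, num actions); masked_mean is copied from the earlier hunk, and all tensor values are random placeholders.

import torch

def masked_mean(t, mask = None):
    if mask is None:
        return t.mean()
    return t[mask].mean() if mask.any() else t[mask].sum()

log_probs = torch.randn(2, 6, 3)      # (batch, time, num actions), hypothetical
old_log_probs = torch.randn(2, 6, 3)
advantage = torch.randn(2, 6, 1)      # already broadcast across actions
alpha, eps_clip = 0.5, 0.2

# pmpo: only the sign of the advantage matters; one scalar mean per sign class
pos_mask = (advantage >= 0.).expand_as(log_probs)
pos = masked_mean(log_probs, pos_mask)
neg = -masked_mean(log_probs, ~pos_mask)
pmpo_loss = -(alpha * pos + (1. - alpha) * neg)

# ppo clipped surrogate on the same inputs, for comparison
ratio = (log_probs - old_log_probs).exp()
clipped_ratio = ratio.clamp(1. - eps_clip, 1. + eps_clip)
ppo_loss = -torch.min(ratio * advantage, clipped_ratio * advantage).sum(dim = -1).mean()

print(pmpo_loss, ppo_loss)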
dreamer4-0.0.92.dist-info/RECORD
ADDED
@@ -0,0 +1,8 @@
+dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
+dreamer4/dreamer4.py,sha256=8RIgxdlgphe19w4bdcy2w7qWqVL4CZtoPfL6LXwj7jE,116380
+dreamer4/mocks.py,sha256=TfqOB_Gq6N_GggBYwa6ZAJQx38ntlYbXZe23Ne4jshw,2502
+dreamer4/trainers.py,sha256=D2b7WTgTHElLhIWLFgl2Ct2knGJLTk91HHpC5UkNvG0,14028
+dreamer4-0.0.92.dist-info/METADATA,sha256=we70JHu-GvXjlcb5HR2bgtn8dc6L-uOARiNdqAKqDTU,3065
+dreamer4-0.0.92.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+dreamer4-0.0.92.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+dreamer4-0.0.92.dist-info/RECORD,,
dreamer4-0.0.90.dist-info/RECORD
DELETED
@@ -1,8 +0,0 @@
-dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
-dreamer4/dreamer4.py,sha256=Ig-t_A8BJWY2eKhsees4_zGXzvtS2JTQTlRuS33ufT8,113812
-dreamer4/mocks.py,sha256=TfqOB_Gq6N_GggBYwa6ZAJQx38ntlYbXZe23Ne4jshw,2502
-dreamer4/trainers.py,sha256=D2b7WTgTHElLhIWLFgl2Ct2knGJLTk91HHpC5UkNvG0,14028
-dreamer4-0.0.90.dist-info/METADATA,sha256=94VfjlhIE6dDY5AbipuRF-Ip7pyhvgQOC4EBKc8ZKRg,3065
-dreamer4-0.0.90.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-dreamer4-0.0.90.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-dreamer4-0.0.90.dist-info/RECORD,,
{dreamer4-0.0.90.dist-info → dreamer4-0.0.92.dist-info}/WHEEL
File without changes
{dreamer4-0.0.90.dist-info → dreamer4-0.0.92.dist-info}/licenses/LICENSE
File without changes