dreamer4 0.0.91.tar.gz → 0.0.93.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dreamer4-0.0.91 → dreamer4-0.0.93}/PKG-INFO +1 -1
- {dreamer4-0.0.91 → dreamer4-0.0.93}/dreamer4/dreamer4.py +101 -6
- {dreamer4-0.0.91 → dreamer4-0.0.93}/pyproject.toml +1 -1
- {dreamer4-0.0.91 → dreamer4-0.0.93}/tests/test_dreamer.py +9 -0
- {dreamer4-0.0.91 → dreamer4-0.0.93}/.github/workflows/python-publish.yml +0 -0
- {dreamer4-0.0.91 → dreamer4-0.0.93}/.github/workflows/test.yml +0 -0
- {dreamer4-0.0.91 → dreamer4-0.0.93}/.gitignore +0 -0
- {dreamer4-0.0.91 → dreamer4-0.0.93}/LICENSE +0 -0
- {dreamer4-0.0.91 → dreamer4-0.0.93}/README.md +0 -0
- {dreamer4-0.0.91 → dreamer4-0.0.93}/dreamer4/__init__.py +0 -0
- {dreamer4-0.0.91 → dreamer4-0.0.93}/dreamer4/mocks.py +0 -0
- {dreamer4-0.0.91 → dreamer4-0.0.93}/dreamer4/trainers.py +0 -0
- {dreamer4-0.0.91 → dreamer4-0.0.93}/dreamer4-fig2.png +0 -0
{dreamer4-0.0.91 → dreamer4-0.0.93}/dreamer4/dreamer4.py

@@ -11,7 +11,7 @@ from dataclasses import dataclass, asdict
 import torch
 import torch.nn.functional as F
 from torch.nested import nested_tensor
-from torch.distributions import Normal
+from torch.distributions import Normal, kl
 from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity
 from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, full, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
 from torch.utils._pytree import tree_flatten, tree_unflatten
@@ -81,6 +81,7 @@ class Experience:
     rewards: Tensor | None = None
     actions: tuple[Tensor, Tensor] | None = None
     log_probs: tuple[Tensor, Tensor] | None = None
+    old_action_unembeds: tuple[Tensor, Tensor] | None = None
     values: Tensor | None = None
     step_size: int | None = None
     lens: Tensor | None = None
@@ -198,6 +199,14 @@ def masked_mean(t, mask = None):
 def log(t, eps = 1e-20):
     return t.clamp(min = eps).log()

+def mean_log_var_to_distr(
+    mean_log_var: Tensor
+) -> Normal:
+
+    mean, log_var = mean_log_var.unbind(dim = -1)
+    std = (0.5 * log_var).exp()
+    return Normal(mean, std)
+
 def safe_cat(tensors, dim):
     tensors = [*filter(exists, tensors)]

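The release factors the (mean, log-variance) → `Normal` conversion into the new `mean_log_var_to_distr` helper. The sketch below reproduces that conversion standalone; the tensor shapes are illustrative, not taken from the package.

```python
import torch
from torch.distributions import Normal

def mean_log_var_to_distr(mean_log_var: torch.Tensor) -> Normal:
    # last dim holds (mean, log_var) pairs, as in the helper added above
    mean, log_var = mean_log_var.unbind(dim = -1)
    std = (0.5 * log_var).exp()  # exp(log_var / 2) is the standard deviation
    return Normal(mean, std)

mean_log_var = torch.randn(2, 3, 2)             # (batch, continuous actions, 2) - illustrative
distr = mean_log_var_to_distr(mean_log_var)

assert torch.allclose(distr.stddev, (0.5 * mean_log_var[..., 1]).exp())
assert distr.sample().shape == (2, 3)
```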
@@ -824,10 +833,7 @@ class ActionEmbedder(Module):
         continuous_entropies = None

         if exists(continuous_targets):
-
-            std = (0.5 * log_var).exp()
-
-            distr = Normal(mean, std)
+            distr = mean_log_var_to_distr(continuous_action_mean_log_var)
             continuous_log_probs = distr.log_prob(continuous_targets)

         if return_entropies:
@@ -842,6 +848,54 @@ class ActionEmbedder(Module):

         return log_probs, entropies

+    def kl_div(
+        self,
+        src: tuple[Tensor | None, Tensor | None],
+        tgt: tuple[Tensor | None, Tensor | None]
+    ) -> tuple[Tensor | None, Tensor | None]:
+
+        src_discrete, src_continuous = src
+        tgt_discrete, tgt_continuous = tgt
+
+        discrete_kl_div = None
+
+        # split discrete if it is not already (multiple discrete actions)
+
+        if exists(src_discrete):
+
+            discrete_split = self.num_discrete_actions.tolist()
+
+            if is_tensor(src_discrete):
+                src_discrete = src_discrete.split(discrete_split, dim = -1)
+
+            if is_tensor(tgt_discrete):
+                tgt_discrete = tgt_discrete.split(discrete_split, dim = -1)
+
+            discrete_kl_divs = []
+
+            for src_logit, tgt_logit in zip(src_discrete, tgt_discrete):
+
+                src_log_probs = src_logit.log_softmax(dim = -1)
+                tgt_prob = tgt_logit.softmax(dim = -1)
+
+                one_discrete_kl_div = F.kl_div(src_log_probs, tgt_prob, reduction = 'none')
+
+                discrete_kl_divs.append(one_discrete_kl_div.sum(dim = -1))
+
+            discrete_kl_div = stack(discrete_kl_divs, dim = -1)
+
+        # calculate kl divergence for continuous
+
+        continuous_kl_div = None
+
+        if exists(src_continuous):
+            src_normal = mean_log_var_to_distr(src_continuous)
+            tgt_normal = mean_log_var_to_distr(tgt_continuous)
+
+            continuous_kl_div = kl.kl_divergence(src_normal, tgt_normal)
+
+        return discrete_kl_div, continuous_kl_div
+
     def forward(
         self,
         *,
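The new `kl_div` method measures how far one set of unembedded action distributions has drifted from another: a per-head categorical KL from split logits, plus a closed-form Gaussian KL for the continuous actions. Below is a self-contained sketch of the same computation outside the class, with hypothetical action counts (two discrete heads of 4 choices each, two continuous actions).

```python
import torch
import torch.nn.functional as F
from torch.distributions import Normal, kl

def mean_log_var_to_distr(mean_log_var):
    mean, log_var = mean_log_var.unbind(dim = -1)
    return Normal(mean, (0.5 * log_var).exp())

num_discrete_actions = [4, 4]                   # hypothetical discrete action head sizes
src_logits = torch.randn(2, 3, sum(num_discrete_actions))
tgt_logits = torch.randn(2, 3, sum(num_discrete_actions))

# per-head categorical KL
# F.kl_div(input_log_probs, target_probs) computes KL(target || input) elementwise
discrete_kl_divs = []
for src_logit, tgt_logit in zip(src_logits.split(num_discrete_actions, dim = -1),
                                tgt_logits.split(num_discrete_actions, dim = -1)):
    elementwise = F.kl_div(src_logit.log_softmax(dim = -1), tgt_logit.softmax(dim = -1), reduction = 'none')
    discrete_kl_divs.append(elementwise.sum(dim = -1))

discrete_kl_div = torch.stack(discrete_kl_divs, dim = -1)       # (2, 3, 2): one KL per discrete head

# continuous actions parameterized as (mean, log_var); closed-form Gaussian KL
src_mean_log_var = torch.randn(2, 3, 2, 2)
tgt_mean_log_var = torch.randn(2, 3, 2, 2)
continuous_kl_div = kl.kl_divergence(
    mean_log_var_to_distr(src_mean_log_var),
    mean_log_var_to_distr(tgt_mean_log_var)
)                                                               # (2, 3, 2): one KL per continuous action
```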
@@ -1834,6 +1888,7 @@ class DynamicsWorldModel(Module):
         gae_lambda = 0.95,
         ppo_eps_clip = 0.2,
         pmpo_pos_to_neg_weight = 0.5, # pos and neg equal weight
+        pmpo_kl_div_loss_weight = 1.,
         value_clip = 0.4,
         policy_entropy_weight = .01,
         gae_use_accelerated = False
@@ -2040,6 +2095,7 @@ class DynamicsWorldModel(Module):
         # pmpo related

         self.pmpo_pos_to_neg_weight = pmpo_pos_to_neg_weight
+        self.pmpo_kl_div_loss_weight = pmpo_kl_div_loss_weight

         # rewards related

@@ -2169,7 +2225,8 @@ class DynamicsWorldModel(Module):
         max_timesteps = 16,
         env_is_vectorized = False,
         use_time_kv_cache = True,
-        store_agent_embed = False
+        store_agent_embed = False,
+        store_old_action_unembeds = False,
     ):
         assert exists(self.video_tokenizer)

@@ -2195,6 +2252,7 @@ class DynamicsWorldModel(Module):
         latents = None

         acc_agent_embed = None
+        acc_policy_embed = None

         # keep track of termination, for setting the `is_truncated` field on Experience and for early stopping interaction with env

@@ -2247,6 +2305,9 @@ class DynamicsWorldModel(Module):

             policy_embed = self.policy_head(one_agent_embed)

+            if store_old_action_unembeds:
+                acc_policy_embed = safe_cat((acc_policy_embed, policy_embed), dim = 1)
+
             # sample actions

             sampled_discrete_actions, sampled_continuous_actions = self.action_embedder.sample(policy_embed, pred_head_index = 0, squeeze = True)
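The accumulator `acc_policy_embed` is built up with the repository's `safe_cat` helper, whose visible behavior (see the hunk around `def safe_cat` above) is to filter out `None` entries so the accumulator can start as `None`. A minimal sketch of that accumulation pattern; `safe_cat`'s full body is not shown in this diff, so the version here is an assumption that simply drops `None`s and concatenates, and the shapes are illustrative.

```python
import torch

def exists(v):
    return v is not None

def safe_cat(tensors, dim):
    # assumption: mirrors the intent of the package helper, of which only the first line is visible here
    tensors = [*filter(exists, tensors)]
    if len(tensors) == 0:
        return None
    return torch.cat(tensors, dim = dim)

acc_policy_embed = None

for _ in range(4):                                  # stand-in for the environment interaction loop
    policy_embed = torch.randn(2, 1, 8)             # (batch, one timestep, policy embed dim)
    acc_policy_embed = safe_cat((acc_policy_embed, policy_embed), dim = 1)

assert acc_policy_embed.shape == (2, 4, 8)          # one policy embed kept per interaction step
```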
@@ -2330,6 +2391,7 @@ class DynamicsWorldModel(Module):
             actions = (discrete_actions, continuous_actions),
             log_probs = (discrete_log_probs, continuous_log_probs),
             values = values,
+            old_action_unembeds = self.action_embedder.unembed(acc_policy_embed, pred_head_index = 0) if store_old_action_unembeds else None,
             agent_embed = acc_agent_embed if store_agent_embed else None,
             step_size = step_size,
             agent_index = agent_index,
@@ -2358,6 +2420,7 @@ class DynamicsWorldModel(Module):
         old_values = experience.values
         rewards = experience.rewards
         agent_embeds = experience.agent_embed
+        old_action_unembeds = experience.old_action_unembeds

         step_size = experience.step_size
         agent_index = experience.agent_index
@@ -2436,6 +2499,7 @@ class DynamicsWorldModel(Module):
         if use_pmpo:
             pos_advantage_mask = advantage >= 0.
             neg_advantage_mask = ~pos_advantage_mask
+
         else:
             advantage = F.layer_norm(advantage, advantage.shape, eps = eps)

@@ -2499,6 +2563,25 @@ class DynamicsWorldModel(Module):

             policy_loss = -(α * pos + (1. - α) * neg)

+            # take care of kl
+
+            if self.pmpo_kl_div_loss_weight > 0.:
+                new_unembedded_actions = self.action_embedder.unembed(policy_embed, pred_head_index = 0)
+
+                discrete_kl_div, continuous_kl_div = self.action_embedder.kl_div(new_unembedded_actions, old_action_unembeds)
+
+                # accumulate discrete and continuous kl div
+
+                kl_div_loss = 0.
+
+                if exists(discrete_kl_div):
+                    kl_div_loss = kl_div_loss + discrete_kl_div[mask].mean()
+
+                if exists(continuous_kl_div):
+                    kl_div_loss = kl_div_loss + continuous_kl_div[mask].mean()
+
+                policy_loss = policy_loss + kl_div_loss * self.pmpo_kl_div_loss_weight
+
         else:
             # ppo clipped surrogate loss

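In the PMPO branch, the current policy is re-unembedded, its KL divergence against the unembeds stored at rollout time is computed, and each available KL tensor is masked to valid timesteps, averaged, and added with weight `pmpo_kl_div_loss_weight` (default `1.` in this release). A simplified sketch of just that combination step, with stand-in tensors:

```python
import torch

policy_loss = torch.tensor(1.25)                    # stand-in for the PMPO objective
discrete_kl_div = torch.rand(2, 16, 2)              # (batch, timesteps, discrete heads) - illustrative
continuous_kl_div = torch.rand(2, 16, 2)            # (batch, timesteps, continuous actions) - illustrative
mask = torch.ones(2, 16, dtype = torch.bool)        # valid (non-padded) timesteps

pmpo_kl_div_loss_weight = 1.

kl_div_loss = 0.
for kl_div in (discrete_kl_div, continuous_kl_div):
    if kl_div is not None:
        kl_div_loss = kl_div_loss + kl_div[mask].mean()   # boolean mask keeps valid steps before averaging

policy_loss = policy_loss + kl_div_loss * pmpo_kl_div_loss_weight
```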
@@ -2641,6 +2724,10 @@ class DynamicsWorldModel(Module):

         acc_agent_embed = None

+        # maybe store old actions for kl
+
+        acc_policy_embed = None
+
         # maybe return rewards

         decoded_rewards = None
@@ -2765,6 +2852,13 @@ class DynamicsWorldModel(Module):

             policy_embed = self.policy_head(one_agent_embed)

+            # maybe store old actions
+
+            if store_old_action_unembeds:
+                acc_policy_embed = safe_cat((acc_policy_embed, policy_embed))
+
+            # sample actions
+
             sampled_discrete_actions, sampled_continuous_actions = self.action_embedder.sample(policy_embed, pred_head_index = 0, squeeze = True)

             decoded_discrete_actions = safe_cat((decoded_discrete_actions, sampled_discrete_actions), dim = 1)
@@ -2853,6 +2947,7 @@ class DynamicsWorldModel(Module):
             video = video,
             proprio = proprio if has_proprio else None,
             agent_embed = acc_agent_embed if store_agent_embed else None,
+            old_action_unembeds = self.action_embedder.unembed(acc_policy_embed, pred_head_index = 0) if store_old_action_unembeds else None,
             step_size = step_size,
             agent_index = agent_index,
             lens = experience_lens,
{dreamer4-0.0.91 → dreamer4-0.0.93}/tests/test_dreamer.py

@@ -346,6 +346,15 @@ def test_action_embedder():
     assert discrete_logits.shape == (2, 3, 8)
     assert continuous_mean_log_var.shape == (2, 3, 2, 2)

+    # test kl div
+
+    discrete_logits_tgt, continuous_mean_log_var_tgt = embedder.unembed(action_embed)
+
+    discrete_kl_div, continuous_kl_div = embedder.kl_div((discrete_logits, continuous_mean_log_var), (discrete_logits_tgt, continuous_mean_log_var_tgt))
+
+    assert discrete_kl_div.shape == (2, 3, 2)
+    assert continuous_kl_div.shape == (2, 3, 2)
+
     # return discrete split by number of actions

     discrete_logits, continuous_mean_log_var = embedder.unembed(action_embed, return_split_discrete = True)
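The added assertions extend the existing `test_action_embedder` test, so the new `kl_div` path can be exercised on its own with pytest's keyword filter, for example `pytest tests/test_dreamer.py -k test_action_embedder`.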