dreamer4 0.0.91__tar.gz → 0.0.93__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of dreamer4 might be problematic.

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: dreamer4
- Version: 0.0.91
+ Version: 0.0.93
  Summary: Dreamer 4
  Project-URL: Homepage, https://pypi.org/project/dreamer4/
  Project-URL: Repository, https://github.com/lucidrains/dreamer4
@@ -11,7 +11,7 @@ from dataclasses import dataclass, asdict
  import torch
  import torch.nn.functional as F
  from torch.nested import nested_tensor
- from torch.distributions import Normal
+ from torch.distributions import Normal, kl
  from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity
  from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, full, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
  from torch.utils._pytree import tree_flatten, tree_unflatten
@@ -81,6 +81,7 @@ class Experience:
  rewards: Tensor | None = None
  actions: tuple[Tensor, Tensor] | None = None
  log_probs: tuple[Tensor, Tensor] | None = None
+ old_action_unembeds: tuple[Tensor, Tensor] | None = None
  values: Tensor | None = None
  step_size: int | None = None
  lens: Tensor | None = None
@@ -198,6 +199,14 @@ def masked_mean(t, mask = None):
  def log(t, eps = 1e-20):
  return t.clamp(min = eps).log()

+ def mean_log_var_to_distr(
+ mean_log_var: Tensor
+ ) -> Normal:
+
+ mean, log_var = mean_log_var.unbind(dim = -1)
+ std = (0.5 * log_var).exp()
+ return Normal(mean, std)
+
  def safe_cat(tensors, dim):
  tensors = [*filter(exists, tensors)]

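For reference, a minimal standalone sketch of what the new `mean_log_var_to_distr` helper does: the trailing size-2 dimension stacks the mean and log-variance of a diagonal Gaussian, and the helper converts it into a `torch.distributions.Normal`. The shapes below are illustrative assumptions, not taken from the package.

    import torch
    from torch.distributions import Normal

    # illustrative tensor: (..., 2), where the last dim stacks (mean, log_var)
    mean_log_var = torch.randn(4, 3, 2)

    mean, log_var = mean_log_var.unbind(dim = -1)
    std = (0.5 * log_var).exp()        # exp(log_var / 2) == sqrt(var)
    distr = Normal(mean, std)

    sample = distr.sample()            # shape (4, 3)
    log_prob = distr.log_prob(sample)  # shape (4, 3)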
@@ -824,10 +833,7 @@ class ActionEmbedder(Module):
  continuous_entropies = None

  if exists(continuous_targets):
- mean, log_var = continuous_action_mean_log_var.unbind(dim = -1)
- std = (0.5 * log_var).exp()
-
- distr = Normal(mean, std)
+ distr = mean_log_var_to_distr(continuous_action_mean_log_var)
  continuous_log_probs = distr.log_prob(continuous_targets)

  if return_entropies:
@@ -842,6 +848,54 @@ class ActionEmbedder(Module):

  return log_probs, entropies

+ def kl_div(
+ self,
+ src: tuple[Tensor | None, Tensor | None],
+ tgt: tuple[Tensor | None, Tensor | None]
+ ) -> tuple[Tensor | None, Tensor | None]:
+
+ src_discrete, src_continuous = src
+ tgt_discrete, tgt_continuous = tgt
+
+ discrete_kl_div = None
+
+ # split discrete if it is not already (multiple discrete actions)
+
+ if exists(src_discrete):
+
+ discrete_split = self.num_discrete_actions.tolist()
+
+ if is_tensor(src_discrete):
+ src_discrete = src_discrete.split(discrete_split, dim = -1)
+
+ if is_tensor(tgt_discrete):
+ tgt_discrete = tgt_discrete.split(discrete_split, dim = -1)
+
+ discrete_kl_divs = []
+
+ for src_logit, tgt_logit in zip(src_discrete, tgt_discrete):
+
+ src_log_probs = src_logit.log_softmax(dim = -1)
+ tgt_prob = tgt_logit.softmax(dim = -1)
+
+ one_discrete_kl_div = F.kl_div(src_log_probs, tgt_prob, reduction = 'none')
+
+ discrete_kl_divs.append(one_discrete_kl_div.sum(dim = -1))
+
+ discrete_kl_div = stack(discrete_kl_divs, dim = -1)
+
+ # calculate kl divergence for continuous
+
+ continuous_kl_div = None
+
+ if exists(src_continuous):
+ src_normal = mean_log_var_to_distr(src_continuous)
+ tgt_normal = mean_log_var_to_distr(tgt_continuous)
+
+ continuous_kl_div = kl.kl_divergence(src_normal, tgt_normal)
+
+ return discrete_kl_div, continuous_kl_div
+
  def forward(
  self,
  *,
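For context, a self-contained sketch of the two KL computations the new `ActionEmbedder.kl_div` method relies on, using toy tensors rather than the package's own classes; all shapes and variable names here are illustrative assumptions.

    import torch
    import torch.nn.functional as F
    from torch.distributions import Normal, kl

    # discrete: KL between two categorical distributions given as logits
    src_logits = torch.randn(2, 3, 8)
    tgt_logits = torch.randn(2, 3, 8)

    src_log_probs = src_logits.log_softmax(dim = -1)
    tgt_probs = tgt_logits.softmax(dim = -1)

    # F.kl_div takes log-probabilities as input and probabilities as target;
    # with reduction = 'none' it returns elementwise terms, summed per distribution.
    # note torch's argument convention: this is KL(target || softmax(src_logits))
    discrete_kl = F.kl_div(src_log_probs, tgt_probs, reduction = 'none').sum(dim = -1)  # (2, 3)

    # continuous: KL between two diagonal Gaussians parameterized as (mean, log_var)
    def to_normal(mean_log_var):
        mean, log_var = mean_log_var.unbind(dim = -1)
        return Normal(mean, (0.5 * log_var).exp())

    src_mean_log_var = torch.randn(2, 3, 2)
    tgt_mean_log_var = torch.randn(2, 3, 2)

    continuous_kl = kl.kl_divergence(to_normal(src_mean_log_var), to_normal(tgt_mean_log_var))  # (2, 3)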
@@ -1834,6 +1888,7 @@ class DynamicsWorldModel(Module):
  gae_lambda = 0.95,
  ppo_eps_clip = 0.2,
  pmpo_pos_to_neg_weight = 0.5, # pos and neg equal weight
+ pmpo_kl_div_loss_weight = 1.,
  value_clip = 0.4,
  policy_entropy_weight = .01,
  gae_use_accelerated = False
@@ -2040,6 +2095,7 @@ class DynamicsWorldModel(Module):
  # pmpo related

  self.pmpo_pos_to_neg_weight = pmpo_pos_to_neg_weight
+ self.pmpo_kl_div_loss_weight = pmpo_kl_div_loss_weight

  # rewards related

@@ -2169,7 +2225,8 @@ class DynamicsWorldModel(Module):
  max_timesteps = 16,
  env_is_vectorized = False,
  use_time_kv_cache = True,
- store_agent_embed = False
+ store_agent_embed = False,
+ store_old_action_unembeds = False,
  ):
  assert exists(self.video_tokenizer)

@@ -2195,6 +2252,7 @@ class DynamicsWorldModel(Module):
  latents = None

  acc_agent_embed = None
+ acc_policy_embed = None

  # keep track of termination, for setting the `is_truncated` field on Experience and for early stopping interaction with env

@@ -2247,6 +2305,9 @@ class DynamicsWorldModel(Module):

  policy_embed = self.policy_head(one_agent_embed)

+ if store_old_action_unembeds:
+ acc_policy_embed = safe_cat((acc_policy_embed, policy_embed), dim = 1)
+
  # sample actions

  sampled_discrete_actions, sampled_continuous_actions = self.action_embedder.sample(policy_embed, pred_head_index = 0, squeeze = True)
@@ -2330,6 +2391,7 @@ class DynamicsWorldModel(Module):
  actions = (discrete_actions, continuous_actions),
  log_probs = (discrete_log_probs, continuous_log_probs),
  values = values,
+ old_action_unembeds = self.action_embedder.unembed(acc_policy_embed, pred_head_index = 0) if store_old_action_unembeds else None,
  agent_embed = acc_agent_embed if store_agent_embed else None,
  step_size = step_size,
  agent_index = agent_index,
@@ -2358,6 +2420,7 @@ class DynamicsWorldModel(Module):
  old_values = experience.values
  rewards = experience.rewards
  agent_embeds = experience.agent_embed
+ old_action_unembeds = experience.old_action_unembeds

  step_size = experience.step_size
  agent_index = experience.agent_index
@@ -2436,6 +2499,7 @@ class DynamicsWorldModel(Module):
  if use_pmpo:
  pos_advantage_mask = advantage >= 0.
  neg_advantage_mask = ~pos_advantage_mask
+
  else:
  advantage = F.layer_norm(advantage, advantage.shape, eps = eps)

@@ -2499,6 +2563,25 @@ class DynamicsWorldModel(Module):

  policy_loss = -(α * pos + (1. - α) * neg)

+ # take care of kl
+
+ if self.pmpo_kl_div_loss_weight > 0.:
+ new_unembedded_actions = self.action_embedder.unembed(policy_embed, pred_head_index = 0)
+
+ discrete_kl_div, continuous_kl_div = self.action_embedder.kl_div(new_unembedded_actions, old_action_unembeds)
+
+ # accumulate discrete and continuous kl div
+
+ kl_div_loss = 0.
+
+ if exists(discrete_kl_div):
+ kl_div_loss = kl_div_loss + discrete_kl_div[mask].mean()
+
+ if exists(continuous_kl_div):
+ kl_div_loss = kl_div_loss + continuous_kl_div[mask].mean()
+
+ policy_loss = policy_loss + kl_div_loss * self.pmpo_kl_div_loss_weight
+
  else:
  # ppo clipped surrogate loss

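A minimal sketch of how the new KL penalty folds into the PMPO policy loss, assuming toy shapes, a placeholder surrogate value, and a boolean validity mask over (batch, time); none of these specifics come from the package.

    import torch

    # stand-ins for the outputs of ActionEmbedder.kl_div and the experience mask
    discrete_kl_div = torch.rand(2, 6, 2)    # (batch, time, num discrete actions)
    continuous_kl_div = torch.rand(2, 6, 2)  # (batch, time, num continuous actions)
    mask = torch.ones(2, 6, dtype = torch.bool)

    policy_loss = torch.tensor(1.5)          # placeholder for the PMPO surrogate term
    pmpo_kl_div_loss_weight = 1.

    # boolean indexing with mask keeps only valid timesteps before averaging
    kl_div_loss = discrete_kl_div[mask].mean() + continuous_kl_div[mask].mean()

    # the weighted KL term keeps the updated policy close to the rollout policy
    policy_loss = policy_loss + kl_div_loss * pmpo_kl_div_loss_weight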
@@ -2641,6 +2724,10 @@ class DynamicsWorldModel(Module):

  acc_agent_embed = None

+ # maybe store old actions for kl
+
+ acc_policy_embed = None
+
  # maybe return rewards

  decoded_rewards = None
@@ -2765,6 +2852,13 @@ class DynamicsWorldModel(Module):

  policy_embed = self.policy_head(one_agent_embed)

+ # maybe store old actions
+
+ if store_old_action_unembeds:
+ acc_policy_embed = safe_cat((acc_policy_embed, policy_embed))
+
+ # sample actions
+
  sampled_discrete_actions, sampled_continuous_actions = self.action_embedder.sample(policy_embed, pred_head_index = 0, squeeze = True)

  decoded_discrete_actions = safe_cat((decoded_discrete_actions, sampled_discrete_actions), dim = 1)
@@ -2853,6 +2947,7 @@ class DynamicsWorldModel(Module):
  video = video,
  proprio = proprio if has_proprio else None,
  agent_embed = acc_agent_embed if store_agent_embed else None,
+ old_action_unembeds = self.action_embedder.unembed(acc_policy_embed, pred_head_index = 0) if store_old_action_unembeds else None,
  step_size = step_size,
  agent_index = agent_index,
  lens = experience_lens,
@@ -1,6 +1,6 @@
  [project]
  name = "dreamer4"
- version = "0.0.91"
+ version = "0.0.93"
  description = "Dreamer 4"
  authors = [
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -346,6 +346,15 @@ def test_action_embedder():
  assert discrete_logits.shape == (2, 3, 8)
  assert continuous_mean_log_var.shape == (2, 3, 2, 2)

+ # test kl div
+
+ discrete_logits_tgt, continuous_mean_log_var_tgt = embedder.unembed(action_embed)
+
+ discrete_kl_div, continuous_kl_div = embedder.kl_div((discrete_logits, continuous_mean_log_var), (discrete_logits_tgt, continuous_mean_log_var_tgt))
+
+ assert discrete_kl_div.shape == (2, 3, 2)
+ assert continuous_kl_div.shape == (2, 3, 2)
+
  # return discrete split by number of actions

  discrete_logits, continuous_mean_log_var = embedder.unembed(action_embed, return_split_discrete = True)
File without changes