dreamer4 0.0.92.tar.gz → 0.0.94.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of dreamer4 might be problematic.

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dreamer4
-Version: 0.0.92
+Version: 0.0.94
 Summary: Dreamer 4
 Project-URL: Homepage, https://pypi.org/project/dreamer4/
 Project-URL: Repository, https://github.com/lucidrains/dreamer4
@@ -81,6 +81,7 @@ class Experience:
     rewards: Tensor | None = None
     actions: tuple[Tensor, Tensor] | None = None
     log_probs: tuple[Tensor, Tensor] | None = None
+    old_action_unembeds: tuple[Tensor, Tensor] | None = None
     values: Tensor | None = None
     step_size: int | None = None
     lens: Tensor | None = None
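
For context, the new `old_action_unembeds` field follows the same (discrete, continuous) tuple convention as `actions` and `log_probs`. A minimal, self-contained sketch of just the fields this hunk touches (not the library's full Experience class):

from dataclasses import dataclass
from torch import Tensor

@dataclass
class ExperienceSketch:
    rewards: Tensor | None = None
    actions: tuple[Tensor, Tensor] | None = None              # (discrete, continuous) actions
    log_probs: tuple[Tensor, Tensor] | None = None            # log-probs matching `actions`
    old_action_unembeds: tuple[Tensor, Tensor] | None = None  # new: unembedded action outputs saved at rollout time, used later for the KL penalty
    values: Tensor | None = None
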
@@ -1887,6 +1888,7 @@ class DynamicsWorldModel(Module):
         gae_lambda = 0.95,
         ppo_eps_clip = 0.2,
         pmpo_pos_to_neg_weight = 0.5, # pos and neg equal weight
+        pmpo_kl_div_loss_weight = 1.,
         value_clip = 0.4,
         policy_entropy_weight = .01,
         gae_use_accelerated = False
@@ -2093,6 +2095,7 @@ class DynamicsWorldModel(Module):
         # pmpo related

         self.pmpo_pos_to_neg_weight = pmpo_pos_to_neg_weight
+        self.pmpo_kl_div_loss_weight = pmpo_kl_div_loss_weight

         # rewards related

@@ -2222,7 +2225,8 @@ class DynamicsWorldModel(Module):
         max_timesteps = 16,
         env_is_vectorized = False,
         use_time_kv_cache = True,
-        store_agent_embed = False
+        store_agent_embed = True,
+        store_old_action_unembeds = True,
     ):
         assert exists(self.video_tokenizer)

@@ -2248,6 +2252,7 @@ class DynamicsWorldModel(Module):
         latents = None

         acc_agent_embed = None
+        acc_policy_embed = None

         # keep track of termination, for setting the `is_truncated` field on Experience and for early stopping interaction with env

@@ -2300,6 +2305,9 @@ class DynamicsWorldModel(Module):

             policy_embed = self.policy_head(one_agent_embed)

+            if store_old_action_unembeds:
+                acc_policy_embed = safe_cat((acc_policy_embed, policy_embed), dim = 1)
+
             # sample actions

             sampled_discrete_actions, sampled_continuous_actions = self.action_embedder.sample(policy_embed, pred_head_index = 0, squeeze = True)
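
`safe_cat` is used above as a None-tolerant concatenation helper that accumulates the per-step policy embeddings along the time axis. A minimal sketch of that accumulation pattern, assuming the helper simply skips None entries (its real implementation is not shown in this diff):

import torch

def safe_cat_sketch(tensors, dim = 1):
    # concatenate while tolerating None entries; return None if nothing to concatenate
    tensors = [t for t in tensors if t is not None]
    if len(tensors) == 0:
        return None
    return torch.cat(tensors, dim = dim)

acc_policy_embed = None
for _ in range(3):
    policy_embed = torch.randn(2, 1, 8)  # (batch, time = 1, embed dim) - shapes are illustrative
    acc_policy_embed = safe_cat_sketch((acc_policy_embed, policy_embed), dim = 1)

print(acc_policy_embed.shape)  # torch.Size([2, 3, 8])
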
@@ -2383,6 +2391,7 @@ class DynamicsWorldModel(Module):
             actions = (discrete_actions, continuous_actions),
             log_probs = (discrete_log_probs, continuous_log_probs),
             values = values,
+            old_action_unembeds = self.action_embedder.unembed(acc_policy_embed, pred_head_index = 0) if exists(acc_policy_embed) and store_old_action_unembeds else None,
             agent_embed = acc_agent_embed if store_agent_embed else None,
             step_size = step_size,
             agent_index = agent_index,
@@ -2411,6 +2420,7 @@ class DynamicsWorldModel(Module):
         old_values = experience.values
         rewards = experience.rewards
         agent_embeds = experience.agent_embed
+        old_action_unembeds = experience.old_action_unembeds

         step_size = experience.step_size
         agent_index = experience.agent_index
@@ -2489,6 +2499,7 @@ class DynamicsWorldModel(Module):
         if use_pmpo:
             pos_advantage_mask = advantage >= 0.
             neg_advantage_mask = ~pos_advantage_mask
+
         else:
             advantage = F.layer_norm(advantage, advantage.shape, eps = eps)

@@ -2552,6 +2563,25 @@ class DynamicsWorldModel(Module):

             policy_loss = -(α * pos + (1. - α) * neg)

+            # take care of kl
+
+            if self.pmpo_kl_div_loss_weight > 0.:
+                new_unembedded_actions = self.action_embedder.unembed(policy_embed, pred_head_index = 0)
+
+                discrete_kl_div, continuous_kl_div = self.action_embedder.kl_div(new_unembedded_actions, old_action_unembeds)
+
+                # accumulate discrete and continuous kl div
+
+                kl_div_loss = 0.
+
+                if exists(discrete_kl_div):
+                    kl_div_loss = kl_div_loss + discrete_kl_div[mask].mean()
+
+                if exists(continuous_kl_div):
+                    kl_div_loss = kl_div_loss + continuous_kl_div[mask].mean()
+
+                policy_loss = policy_loss + kl_div_loss * self.pmpo_kl_div_loss_weight
+
         else:
             # ppo clipped surrogate loss

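
The added block masks out padded timesteps before averaging each KL term, then folds the weighted sum into the policy loss. A self-contained sketch of that accumulation, with stand-in values where this diff relies on `action_embedder.kl_div` and the surrounding PMPO terms:

import torch

pmpo_kl_div_loss_weight = 1.

policy_loss = torch.tensor(0.25)                # stand-in for -(α * pos + (1. - α) * neg)
mask = torch.tensor([True, True, False, True])  # valid (non-padded) timesteps

discrete_kl_div = torch.rand(4)                 # per-timestep KL for the discrete action head (illustrative values)
continuous_kl_div = None                        # e.g. the environment has no continuous actions

kl_div_loss = 0.

if discrete_kl_div is not None:
    kl_div_loss = kl_div_loss + discrete_kl_div[mask].mean()

if continuous_kl_div is not None:
    kl_div_loss = kl_div_loss + continuous_kl_div[mask].mean()

policy_loss = policy_loss + kl_div_loss * pmpo_kl_div_loss_weight
print(policy_loss)
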
@@ -2637,7 +2667,8 @@ class DynamicsWorldModel(Module):
         return_agent_actions = False,
         return_log_probs_and_values = False,
         return_time_kv_cache = False,
-        store_agent_embed = False
+        store_agent_embed = True,
+        store_old_action_unembeds = True

     ): # (b t n d) | (b c t h w)

@@ -2694,6 +2725,10 @@ class DynamicsWorldModel(Module):

         acc_agent_embed = None

+        # maybe store old actions for kl
+
+        acc_policy_embed = None
+
         # maybe return rewards

         decoded_rewards = None
@@ -2818,6 +2853,13 @@ class DynamicsWorldModel(Module):

             policy_embed = self.policy_head(one_agent_embed)

+            # maybe store old actions
+
+            if store_old_action_unembeds:
+                acc_policy_embed = safe_cat((acc_policy_embed, policy_embed))
+
+            # sample actions
+
             sampled_discrete_actions, sampled_continuous_actions = self.action_embedder.sample(policy_embed, pred_head_index = 0, squeeze = True)

             decoded_discrete_actions = safe_cat((decoded_discrete_actions, sampled_discrete_actions), dim = 1)
@@ -2906,6 +2948,7 @@ class DynamicsWorldModel(Module):
             video = video,
             proprio = proprio if has_proprio else None,
             agent_embed = acc_agent_embed if store_agent_embed else None,
+            old_action_unembeds = self.action_embedder.unembed(acc_policy_embed, pred_head_index = 0) if exists(acc_policy_embed) and store_old_action_unembeds else None,
             step_size = step_size,
             agent_index = agent_index,
             lens = experience_lens,
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.92"
+version = "0.0.94"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }