dreamer4 0.0.92-py3-none-any.whl → 0.0.93-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dreamer4/dreamer4.py CHANGED
@@ -81,6 +81,7 @@ class Experience:
     rewards: Tensor | None = None
     actions: tuple[Tensor, Tensor] | None = None
     log_probs: tuple[Tensor, Tensor] | None = None
+    old_action_unembeds: tuple[Tensor, Tensor] | None = None
     values: Tensor | None = None
     step_size: int | None = None
     lens: Tensor | None = None
@@ -1887,6 +1888,7 @@ class DynamicsWorldModel(Module):
         gae_lambda = 0.95,
         ppo_eps_clip = 0.2,
         pmpo_pos_to_neg_weight = 0.5, # pos and neg equal weight
+        pmpo_kl_div_loss_weight = 1.,
         value_clip = 0.4,
         policy_entropy_weight = .01,
         gae_use_accelerated = False
@@ -2093,6 +2095,7 @@ class DynamicsWorldModel(Module):
         # pmpo related

         self.pmpo_pos_to_neg_weight = pmpo_pos_to_neg_weight
+        self.pmpo_kl_div_loss_weight = pmpo_kl_div_loss_weight

         # rewards related

@@ -2222,7 +2225,8 @@ class DynamicsWorldModel(Module):
         max_timesteps = 16,
         env_is_vectorized = False,
         use_time_kv_cache = True,
-        store_agent_embed = False
+        store_agent_embed = False,
+        store_old_action_unembeds = False,
     ):
         assert exists(self.video_tokenizer)

@@ -2248,6 +2252,7 @@ class DynamicsWorldModel(Module):
         latents = None

         acc_agent_embed = None
+        acc_policy_embed = None

         # keep track of termination, for setting the `is_truncated` field on Experience and for early stopping interaction with env

@@ -2300,6 +2305,9 @@ class DynamicsWorldModel(Module):

             policy_embed = self.policy_head(one_agent_embed)

+            if store_old_action_unembeds:
+                acc_policy_embed = safe_cat((acc_policy_embed, policy_embed), dim = 1)
+
             # sample actions

             sampled_discrete_actions, sampled_continuous_actions = self.action_embedder.sample(policy_embed, pred_head_index = 0, squeeze = True)
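Note: `safe_cat` is an existing helper in dreamer4.py that this diff does not show. For the hunk above it is assumed to behave like a None-tolerant `torch.cat`, so `acc_policy_embed` can start out as `None` and grow along the time dimension. A minimal sketch of that assumed behavior (a hypothetical reimplementation, not the package's code):

import torch
from torch import Tensor

# assumed behavior of safe_cat: concatenate the non-None tensors,
# or return None if nothing has been accumulated yet
def safe_cat(tensors: tuple[Tensor | None, ...], dim: int = 0) -> Tensor | None:
    present = tuple(t for t in tensors if t is not None)
    if len(present) == 0:
        return None
    return torch.cat(present, dim = dim)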
@@ -2383,6 +2391,7 @@ class DynamicsWorldModel(Module):
             actions = (discrete_actions, continuous_actions),
             log_probs = (discrete_log_probs, continuous_log_probs),
             values = values,
+            old_action_unembeds = self.action_embedder.unembed(acc_policy_embed, pred_head_index = 0) if store_old_action_unembeds else None,
             agent_embed = acc_agent_embed if store_agent_embed else None,
             step_size = step_size,
             agent_index = agent_index,
@@ -2411,6 +2420,7 @@ class DynamicsWorldModel(Module):
         old_values = experience.values
         rewards = experience.rewards
         agent_embeds = experience.agent_embed
+        old_action_unembeds = experience.old_action_unembeds

         step_size = experience.step_size
         agent_index = experience.agent_index
@@ -2489,6 +2499,7 @@ class DynamicsWorldModel(Module):
         if use_pmpo:
             pos_advantage_mask = advantage >= 0.
             neg_advantage_mask = ~pos_advantage_mask
+
         else:
             advantage = F.layer_norm(advantage, advantage.shape, eps = eps)

@@ -2552,6 +2563,25 @@ class DynamicsWorldModel(Module):

             policy_loss = -(α * pos + (1. - α) * neg)

+            # take care of kl
+
+            if self.pmpo_kl_div_loss_weight > 0.:
+                new_unembedded_actions = self.action_embedder.unembed(policy_embed, pred_head_index = 0)
+
+                discrete_kl_div, continuous_kl_div = self.action_embedder.kl_div(new_unembedded_actions, old_action_unembeds)
+
+                # accumulate discrete and continuous kl div
+
+                kl_div_loss = 0.
+
+                if exists(discrete_kl_div):
+                    kl_div_loss = kl_div_loss + discrete_kl_div[mask].mean()
+
+                if exists(continuous_kl_div):
+                    kl_div_loss = kl_div_loss + continuous_kl_div[mask].mean()
+
+                policy_loss = policy_loss + kl_div_loss * self.pmpo_kl_div_loss_weight
+
         else:
             # ppo clipped surrogate loss

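The hunk above is the core of 0.0.93: when `pmpo_kl_div_loss_weight > 0`, the PMPO policy loss gains a KL-divergence penalty between the current policy's unembedded actions and the `old_action_unembeds` captured at rollout time, masked to valid timesteps and averaged. The diff does not show how `action_embedder.kl_div` parameterizes the distributions or which KL direction it uses; the sketch below assumes the discrete head produces logits over action bins and uses KL(old || new) purely for illustration (the helper name, `mask`, and shapes are hypothetical):

import torch
from torch import Tensor

def masked_discrete_kl_penalty(
    new_logits: Tensor,   # (batch, time, num_bins) - unembeds from the policy being updated
    old_logits: Tensor,   # (batch, time, num_bins) - unembeds stored on the Experience at rollout
    mask: Tensor,         # (batch, time) bool - valid timesteps, mirroring `discrete_kl_div[mask]`
    weight: float = 1.,   # plays the role of pmpo_kl_div_loss_weight
) -> Tensor:
    # per-timestep KL(old || new) over the discrete action bins (the direction is an assumption)
    old_log_probs = old_logits.log_softmax(dim = -1)
    new_log_probs = new_logits.log_softmax(dim = -1)
    kl = (old_log_probs.exp() * (old_log_probs - new_log_probs)).sum(dim = -1)
    # average only over valid timesteps, then scale, as in the hunk above
    return kl[mask].mean() * weight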
@@ -2694,6 +2724,10 @@ class DynamicsWorldModel(Module):

         acc_agent_embed = None

+        # maybe store old actions for kl
+
+        acc_policy_embed = None
+
         # maybe return rewards

         decoded_rewards = None
@@ -2818,6 +2852,13 @@ class DynamicsWorldModel(Module):

             policy_embed = self.policy_head(one_agent_embed)

+            # maybe store old actions
+
+            if store_old_action_unembeds:
+                acc_policy_embed = safe_cat((acc_policy_embed, policy_embed))
+
+            # sample actions
+
             sampled_discrete_actions, sampled_continuous_actions = self.action_embedder.sample(policy_embed, pred_head_index = 0, squeeze = True)

             decoded_discrete_actions = safe_cat((decoded_discrete_actions, sampled_discrete_actions), dim = 1)
@@ -2906,6 +2947,7 @@ class DynamicsWorldModel(Module):
             video = video,
             proprio = proprio if has_proprio else None,
             agent_embed = acc_agent_embed if store_agent_embed else None,
+            old_action_unembeds = self.action_embedder.unembed(acc_policy_embed, pred_head_index = 0) if store_old_action_unembeds else None,
             step_size = step_size,
             agent_index = agent_index,
             lens = experience_lens,
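Taken together, the dreamer4.py changes wire one feature end to end: both rollout paths optionally accumulate the policy head's outputs and store their unembedded form on `Experience.old_action_unembeds`, and the PMPO update then regularizes the new policy toward those stored distributions with weight `pmpo_kl_div_loss_weight`. A rough usage sketch follows; apart from the two new options, the constructor arguments and method names are placeholders (the diff does not show the full API), so treat this as pseudocode rather than the package's actual interface:

# hypothetical sketch - `collect_experience` and `update_policy` are placeholder names;
# only `pmpo_kl_div_loss_weight` and `store_old_action_unembeds` are taken from this diff

model = DynamicsWorldModel(
    video_tokenizer = tokenizer,       # placeholder for the usual configuration
    pmpo_kl_div_loss_weight = 1.,      # new in 0.0.93: > 0 enables the KL penalty in the PMPO branch
)

# 1. rollout with the new flag, so the Experience carries the rollout-time action distributions
experience = model.collect_experience(
    env,
    max_timesteps = 16,
    store_old_action_unembeds = True,  # new in 0.0.93
)

# experience.old_action_unembeds -> tuple[Tensor, Tensor] | None, i.e. (discrete, continuous)

# 2. PMPO update: adds kl_div_loss * pmpo_kl_div_loss_weight to the policy loss
model.update_policy(experience, use_pmpo = True)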
dreamer4-0.0.92.dist-info/METADATA → dreamer4-0.0.93.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dreamer4
-Version: 0.0.92
+Version: 0.0.93
 Summary: Dreamer 4
 Project-URL: Homepage, https://pypi.org/project/dreamer4/
 Project-URL: Repository, https://github.com/lucidrains/dreamer4
dreamer4-0.0.93.dist-info/RECORD ADDED
@@ -0,0 +1,8 @@
+dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
+dreamer4/dreamer4.py,sha256=B2Bk6JJO9MVTWwss9hOP1k6SBiEr56ijNOa3PiidPnY,118120
+dreamer4/mocks.py,sha256=TfqOB_Gq6N_GggBYwa6ZAJQx38ntlYbXZe23Ne4jshw,2502
+dreamer4/trainers.py,sha256=D2b7WTgTHElLhIWLFgl2Ct2knGJLTk91HHpC5UkNvG0,14028
+dreamer4-0.0.93.dist-info/METADATA,sha256=FhVnlhfeloUPMiFqqJ5qR6fqdd7YmN1-gXykkOTPF_A,3065
+dreamer4-0.0.93.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+dreamer4-0.0.93.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+dreamer4-0.0.93.dist-info/RECORD,,
dreamer4-0.0.92.dist-info/RECORD DELETED
@@ -1,8 +0,0 @@
-dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
-dreamer4/dreamer4.py,sha256=8RIgxdlgphe19w4bdcy2w7qWqVL4CZtoPfL6LXwj7jE,116380
-dreamer4/mocks.py,sha256=TfqOB_Gq6N_GggBYwa6ZAJQx38ntlYbXZe23Ne4jshw,2502
-dreamer4/trainers.py,sha256=D2b7WTgTHElLhIWLFgl2Ct2knGJLTk91HHpC5UkNvG0,14028
-dreamer4-0.0.92.dist-info/METADATA,sha256=we70JHu-GvXjlcb5HR2bgtn8dc6L-uOARiNdqAKqDTU,3065
-dreamer4-0.0.92.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-dreamer4-0.0.92.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-dreamer4-0.0.92.dist-info/RECORD,,