dreamer4 0.0.83__tar.gz → 0.0.85__tar.gz

This diff shows the changes between publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.


@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dreamer4
-Version: 0.0.83
+Version: 0.0.85
 Summary: Dreamer 4
 Project-URL: Homepage, https://pypi.org/project/dreamer4/
 Project-URL: Repository, https://github.com/lucidrains/dreamer4
@@ -77,6 +77,7 @@ class Experience:
     latents: Tensor
     video: Tensor | None = None
     proprio: Tensor | None = None
+    agent_embed: Tensor | None = None,
     rewards: Tensor | None = None
     actions: tuple[Tensor, Tensor] | None = None
     log_probs: tuple[Tensor, Tensor] | None = None
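
For context, the new field slots into the experience container roughly as sketched below. This is a minimal reconstruction from the hunk above, not the full source; fields outside the hunk are omitted and the dataclass decorator is assumed.

from dataclasses import dataclass
from torch import Tensor

# minimal sketch of the Experience container after this change,
# reconstructed from the hunk above
@dataclass
class Experience:
    latents: Tensor
    video: Tensor | None = None
    proprio: Tensor | None = None
    # new in 0.0.85; note the released hunk carries a trailing comma,
    # which Python would read as a one-element tuple default rather than None
    agent_embed: Tensor | None = None
    rewards: Tensor | None = None
    actions: tuple[Tensor, Tensor] | None = None
    log_probs: tuple[Tensor, Tensor] | None = None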
@@ -2105,7 +2106,8 @@ class DynamicsWorldModel(Module):
        step_size = 4,
        max_timesteps = 16,
        env_is_vectorized = False,
-       use_time_kv_cache = True
+       use_time_kv_cache = True,
+       store_agent_embed = False
    ):
        assert exists(self.video_tokenizer)
 
@@ -2130,6 +2132,8 @@ class DynamicsWorldModel(Module):
        values = None
        latents = None
 
+       acc_agent_embed = None
+
        # keep track of termination, for setting the `is_truncated` field on Experience and for early stopping interaction with env
 
        is_terminated = full((batch,), False, device = device)
@@ -2251,6 +2255,8 @@ class DynamicsWorldModel(Module):
            video = cat((video, next_frame), dim = 2)
            rewards = safe_cat((rewards, reward), dim = 1)
 
+           acc_agent_embed = safe_cat((acc_agent_embed, agent_embed), dim = 1)
+
        # package up one experience for learning
 
        batch, device = latents.shape[0], latents.device
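
The rollout accumulates the per-step embedding with the same None-tolerant concat already used for rewards. The helper itself is not shown in this diff; a minimal sketch of the assumed behavior (drop missing tensors, return None if nothing is left) looks like this:

import torch
from torch import Tensor

# hedged sketch of a None-tolerant concat matching how `safe_cat` is used
# above; the real helper lives in the dreamer4 source and may differ
def safe_cat(tensors: tuple[Tensor | None, ...], dim = 0) -> Tensor | None:
    tensors = tuple(t for t in tensors if t is not None)
    if len(tensors) == 0:
        return None
    return torch.cat(tensors, dim = dim)

# accumulation pattern from the hunk: the accumulator starts as None and
# grows along the time dimension one step at a time
acc_agent_embed = None
for agent_embed in (torch.randn(2, 1, 8), torch.randn(2, 1, 8)):
    acc_agent_embed = safe_cat((acc_agent_embed, agent_embed), dim = 1)

assert acc_agent_embed.shape == (2, 2, 8)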
@@ -2262,6 +2268,7 @@ class DynamicsWorldModel(Module):
            actions = (discrete_actions, continuous_actions),
            log_probs = (discrete_log_probs, continuous_log_probs),
            values = values,
+           agent_embed = acc_agent_embed if store_agent_embed else None,
            step_size = step_size,
            agent_index = agent_index,
            is_truncated = is_truncated,
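
Taken together with the signature change above, rollouts can now optionally keep the per-step agent embeddings on the returned Experience. A minimal usage sketch, where `world_model` and `env` stand in for a constructed DynamicsWorldModel and an environment it can interact with (both placeholders, not defined in this diff):

# hedged usage sketch; `world_model` and `env` are placeholders
experience = world_model.interact_with_env(
    env,
    max_timesteps = 16,
    store_agent_embed = True   # new flag in 0.0.85, defaults to False
)

# with the flag set, the accumulated embeddings travel with the experience;
# otherwise the field stays None
assert experience.agent_embed is not None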
@@ -2491,6 +2498,7 @@ class DynamicsWorldModel(Module):
        return_agent_actions = False,
        return_log_probs_and_values = False,
        return_time_kv_cache = False,
+       store_agent_embed = False
 
    ): # (b t n d) | (b c t h w)
 
@@ -2543,6 +2551,10 @@ class DynamicsWorldModel(Module):
        decoded_continuous_log_probs = None
        decoded_values = None
 
+       # maybe store agent embed
+
+       acc_agent_embed = None
+
        # maybe return rewards
 
        decoded_rewards = None
@@ -2651,6 +2663,10 @@ class DynamicsWorldModel(Module):
 
            decoded_rewards = cat((decoded_rewards, pred_reward), dim = 1)
 
+           # maybe store agent embed
+
+           acc_agent_embed = safe_cat((acc_agent_embed, agent_embed), dim = 1)
+
        # decode the agent actions if needed
 
        if return_agent_actions:
@@ -2747,6 +2763,7 @@ class DynamicsWorldModel(Module):
            latents = latents,
            video = video,
            proprio = proprio if has_proprio else None,
+           agent_embed = acc_agent_embed if store_agent_embed else None,
            step_size = step_size,
            agent_index = agent_index,
            lens = experience_lens,
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.83"
+version = "0.0.85"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -614,11 +614,13 @@ def test_cache_generate():
 @param('use_signed_advantage', (False, True))
 @param('env_can_terminate', (False, True))
 @param('env_can_truncate', (False, True))
+@param('store_agent_embed', (False, True))
 def test_online_rl(
     vectorized,
     use_signed_advantage,
     env_can_terminate,
-    env_can_truncate
+    env_can_truncate,
+    store_agent_embed
 ):
     from dreamer4.dreamer4 import DynamicsWorldModel, VideoTokenizer
 
@@ -664,11 +666,14 @@ def test_online_rl(
 
     # manually
 
-    one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = vectorized)
-    another_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized)
+    one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = vectorized, store_agent_embed = store_agent_embed)
+    another_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized, store_agent_embed = store_agent_embed)
 
     combined_experience = combine_experiences([one_experience, another_experience])
 
+    if store_agent_embed:
+        assert exists(combined_experience.agent_embed)
+
     actor_loss, critic_loss = world_model_and_policy.learn_from_experience(combined_experience, use_signed_advantage = use_signed_advantage)
 
     actor_loss.backward()
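
With the new parametrization in place, the flag is exercised end to end by the existing online RL test: both the stored and non-stored cases run under something along the lines of `pytest -k test_online_rl`, assuming the project's usual pytest setup.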