dreamer4 0.0.83__py3-none-any.whl → 0.0.84__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dreamer4/dreamer4.py CHANGED
@@ -77,6 +77,7 @@ class Experience:
77
77
  latents: Tensor
78
78
  video: Tensor | None = None
79
79
  proprio: Tensor | None = None
80
+ agent_embed: Tensor | None = None,
80
81
  rewards: Tensor | None = None
81
82
  actions: tuple[Tensor, Tensor] | None = None
82
83
  log_probs: tuple[Tensor, Tensor] | None = None
@@ -2105,7 +2106,8 @@ class DynamicsWorldModel(Module):
2105
2106
  step_size = 4,
2106
2107
  max_timesteps = 16,
2107
2108
  env_is_vectorized = False,
2108
- use_time_kv_cache = True
2109
+ use_time_kv_cache = True,
2110
+ store_agent_embed = True
2109
2111
  ):
2110
2112
  assert exists(self.video_tokenizer)
2111
2113
 
@@ -2130,6 +2132,8 @@ class DynamicsWorldModel(Module):
2130
2132
  values = None
2131
2133
  latents = None
2132
2134
 
2135
+ acc_agent_embed = None
2136
+
2133
2137
  # keep track of termination, for setting the `is_truncated` field on Experience and for early stopping interaction with env
2134
2138
 
2135
2139
  is_terminated = full((batch,), False, device = device)
@@ -2251,6 +2255,8 @@ class DynamicsWorldModel(Module):
2251
2255
  video = cat((video, next_frame), dim = 2)
2252
2256
  rewards = safe_cat((rewards, reward), dim = 1)
2253
2257
 
2258
+ acc_agent_embed = safe_cat((acc_agent_embed, agent_embed), dim = 1)
2259
+
2254
2260
  # package up one experience for learning
2255
2261
 
2256
2262
  batch, device = latents.shape[0], latents.device
@@ -2262,6 +2268,7 @@ class DynamicsWorldModel(Module):
2262
2268
  actions = (discrete_actions, continuous_actions),
2263
2269
  log_probs = (discrete_log_probs, continuous_log_probs),
2264
2270
  values = values,
2271
+ agent_embed = acc_agent_embed if store_agent_embed else None,
2265
2272
  step_size = step_size,
2266
2273
  agent_index = agent_index,
2267
2274
  is_truncated = is_truncated,
@@ -2491,6 +2498,7 @@ class DynamicsWorldModel(Module):
2491
2498
  return_agent_actions = False,
2492
2499
  return_log_probs_and_values = False,
2493
2500
  return_time_kv_cache = False,
2501
+ store_agent_embed = False
2494
2502
 
2495
2503
  ): # (b t n d) | (b c t h w)
2496
2504
 
@@ -2543,6 +2551,10 @@ class DynamicsWorldModel(Module):
2543
2551
  decoded_continuous_log_probs = None
2544
2552
  decoded_values = None
2545
2553
 
2554
+ # maybe store agent embed
2555
+
2556
+ acc_agent_embed = None
2557
+
2546
2558
  # maybe return rewards
2547
2559
 
2548
2560
  decoded_rewards = None
@@ -2651,6 +2663,10 @@ class DynamicsWorldModel(Module):
2651
2663
 
2652
2664
  decoded_rewards = cat((decoded_rewards, pred_reward), dim = 1)
2653
2665
 
2666
+ # maybe store agent embed
2667
+
2668
+ acc_agent_embed = safe_cat((acc_agent_embed, agent_embed), dim = 1)
2669
+
2654
2670
  # decode the agent actions if needed
2655
2671
 
2656
2672
  if return_agent_actions:
@@ -2747,6 +2763,7 @@ class DynamicsWorldModel(Module):
2747
2763
  latents = latents,
2748
2764
  video = video,
2749
2765
  proprio = proprio if has_proprio else None,
2766
+ agent_embed = acc_agent_embed if store_agent_embed else None,
2750
2767
  step_size = step_size,
2751
2768
  agent_index = agent_index,
2752
2769
  lens = experience_lens,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: dreamer4
3
- Version: 0.0.83
3
+ Version: 0.0.84
4
4
  Summary: Dreamer 4
5
5
  Project-URL: Homepage, https://pypi.org/project/dreamer4/
6
6
  Project-URL: Repository, https://github.com/lucidrains/dreamer4
@@ -0,0 +1,8 @@
1
+ dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
2
+ dreamer4/dreamer4.py,sha256=jldgU8K1fSx8Eb9v1VYUhtbVSYZnHIEtQgx-WMnzep4,111820
3
+ dreamer4/mocks.py,sha256=TfqOB_Gq6N_GggBYwa6ZAJQx38ntlYbXZe23Ne4jshw,2502
4
+ dreamer4/trainers.py,sha256=D2b7WTgTHElLhIWLFgl2Ct2knGJLTk91HHpC5UkNvG0,14028
5
+ dreamer4-0.0.84.dist-info/METADATA,sha256=4D7elrDaDKYR2VhD94MZ5gsCFohvKIAkD1SFf_atx3w,3065
6
+ dreamer4-0.0.84.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
+ dreamer4-0.0.84.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
+ dreamer4-0.0.84.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
2
- dreamer4/dreamer4.py,sha256=t9g-fJVRPn4nefpNiKwqkIkKZtIeJAn5V4ruNCJ5a9A,111265
3
- dreamer4/mocks.py,sha256=TfqOB_Gq6N_GggBYwa6ZAJQx38ntlYbXZe23Ne4jshw,2502
4
- dreamer4/trainers.py,sha256=D2b7WTgTHElLhIWLFgl2Ct2knGJLTk91HHpC5UkNvG0,14028
5
- dreamer4-0.0.83.dist-info/METADATA,sha256=pYIw5Pj41JhfJ4Hp9AfegXrxnhIZ9Fk98kF7lY3nAZk,3065
6
- dreamer4-0.0.83.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
- dreamer4-0.0.83.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
- dreamer4-0.0.83.dist-info/RECORD,,