dreamer4 0.0.93__py3-none-any.whl → 0.0.95__py3-none-any.whl

This diff shows the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of dreamer4 might be problematic. Click here for more details.

dreamer4/dreamer4.py CHANGED
@@ -72,20 +72,22 @@ TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips'))
72
72
 
73
73
  WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'rewards', 'discrete_actions', 'continuous_actions'))
74
74
 
75
+ MaybeTensor = Tensor | None
76
+
75
77
  @dataclass
76
78
  class Experience:
77
79
  latents: Tensor
78
- video: Tensor | None = None
79
- proprio: Tensor | None = None
80
- agent_embed: Tensor | None = None
80
+ video: MaybeTensor = None
81
+ proprio: MaybeTensor = None
82
+ agent_embed: MaybeTensor = None
81
83
  rewards: Tensor | None = None
82
- actions: tuple[Tensor, Tensor] | None = None
83
- log_probs: tuple[Tensor, Tensor] | None = None
84
- old_action_unembeds: tuple[Tensor, Tensor] | None = None
85
- values: Tensor | None = None
84
+ actions: tuple[MaybeTensor, MaybeTensor] | None = None
85
+ log_probs: tuple[MaybeTensor, MaybeTensor] | None = None
86
+ old_action_unembeds: tuple[MaybeTensor, MaybeTensor] | None = None
87
+ values: MaybeTensor = None
86
88
  step_size: int | None = None
87
- lens: Tensor | None = None
88
- is_truncated: Tensor | None = None
89
+ lens: MaybeTensor = None
90
+ is_truncated: MaybeTensor = None
89
91
  agent_index: int = 0
90
92
  is_from_world_model: bool = True
91
93
 
@@ -850,9 +852,10 @@ class ActionEmbedder(Module):
850
852
 
851
853
  def kl_div(
852
854
  self,
853
- src: tuple[Tensor | None, Tensor | None],
854
- tgt: tuple[Tensor | None, Tensor | None]
855
- ) -> tuple[Tensor | None, Tensor | None]:
855
+ src: tuple[MaybeTensor, MaybeTensor],
856
+ tgt: tuple[MaybeTensor, MaybeTensor],
857
+ reduce_across_num_actions = True
858
+ ) -> tuple[MaybeTensor, MaybeTensor]:
856
859
 
857
860
  src_discrete, src_continuous = src
858
861
  tgt_discrete, tgt_continuous = tgt
@@ -894,6 +897,15 @@ class ActionEmbedder(Module):
894
897
 
895
898
  continuous_kl_div = kl.kl_divergence(src_normal, tgt_normal)
896
899
 
900
+ # maybe reduce
901
+
902
+ if reduce_across_num_actions:
903
+ if exists(discrete_kl_div):
904
+ discrete_kl_div = discrete_kl_div.sum(dim = -1)
905
+
906
+ if exists(continuous_kl_div):
907
+ continuous_kl_div = continuous_kl_div.sum(dim = -1)
908
+
897
909
  return discrete_kl_div, continuous_kl_div
898
910
 
899
911
  def forward(
@@ -2225,8 +2237,8 @@ class DynamicsWorldModel(Module):
2225
2237
  max_timesteps = 16,
2226
2238
  env_is_vectorized = False,
2227
2239
  use_time_kv_cache = True,
2228
- store_agent_embed = False,
2229
- store_old_action_unembeds = False,
2240
+ store_agent_embed = True,
2241
+ store_old_action_unembeds = True,
2230
2242
  ):
2231
2243
  assert exists(self.video_tokenizer)
2232
2244
 
@@ -2391,7 +2403,7 @@ class DynamicsWorldModel(Module):
2391
2403
  actions = (discrete_actions, continuous_actions),
2392
2404
  log_probs = (discrete_log_probs, continuous_log_probs),
2393
2405
  values = values,
2394
- old_action_unembeds = self.action_embedder.unembed(acc_policy_embed, pred_head_index = 0) if store_old_action_unembeds else None,
2406
+ old_action_unembeds = self.action_embedder.unembed(acc_policy_embed, pred_head_index = 0) if exists(acc_policy_embed) and store_old_action_unembeds else None,
2395
2407
  agent_embed = acc_agent_embed if store_agent_embed else None,
2396
2408
  step_size = step_size,
2397
2409
  agent_index = agent_index,
@@ -2667,7 +2679,8 @@ class DynamicsWorldModel(Module):
2667
2679
  return_agent_actions = False,
2668
2680
  return_log_probs_and_values = False,
2669
2681
  return_time_kv_cache = False,
2670
- store_agent_embed = False
2682
+ store_agent_embed = True,
2683
+ store_old_action_unembeds = True
2671
2684
 
2672
2685
  ): # (b t n d) | (b c t h w)
2673
2686
 
@@ -2947,7 +2960,7 @@ class DynamicsWorldModel(Module):
2947
2960
  video = video,
2948
2961
  proprio = proprio if has_proprio else None,
2949
2962
  agent_embed = acc_agent_embed if store_agent_embed else None,
2950
- old_action_unembeds = self.action_embedder.unembed(acc_policy_embed, pred_head_index = 0) if store_old_action_unembeds else None,
2963
+ old_action_unembeds = self.action_embedder.unembed(acc_policy_embed, pred_head_index = 0) if exists(acc_policy_embed) and store_old_action_unembeds else None,
2951
2964
  step_size = step_size,
2952
2965
  agent_index = agent_index,
2953
2966
  lens = experience_lens,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: dreamer4
3
- Version: 0.0.93
3
+ Version: 0.0.95
4
4
  Summary: Dreamer 4
5
5
  Project-URL: Homepage, https://pypi.org/project/dreamer4/
6
6
  Project-URL: Repository, https://github.com/lucidrains/dreamer4
@@ -0,0 +1,8 @@
1
+ dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
2
+ dreamer4/dreamer4.py,sha256=pSwW4DpHMF78ortsHLVHDdWzrsjse6bIVxX3oolA-Ao,118572
3
+ dreamer4/mocks.py,sha256=TfqOB_Gq6N_GggBYwa6ZAJQx38ntlYbXZe23Ne4jshw,2502
4
+ dreamer4/trainers.py,sha256=D2b7WTgTHElLhIWLFgl2Ct2knGJLTk91HHpC5UkNvG0,14028
5
+ dreamer4-0.0.95.dist-info/METADATA,sha256=VWiRy5xotYsn9HL5EFymn9N8j8-_wFKbYeTA5k6E0z4,3065
6
+ dreamer4-0.0.95.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
+ dreamer4-0.0.95.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
+ dreamer4-0.0.95.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
2
- dreamer4/dreamer4.py,sha256=B2Bk6JJO9MVTWwss9hOP1k6SBiEr56ijNOa3PiidPnY,118120
3
- dreamer4/mocks.py,sha256=TfqOB_Gq6N_GggBYwa6ZAJQx38ntlYbXZe23Ne4jshw,2502
4
- dreamer4/trainers.py,sha256=D2b7WTgTHElLhIWLFgl2Ct2knGJLTk91HHpC5UkNvG0,14028
5
- dreamer4-0.0.93.dist-info/METADATA,sha256=FhVnlhfeloUPMiFqqJ5qR6fqdd7YmN1-gXykkOTPF_A,3065
6
- dreamer4-0.0.93.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
- dreamer4-0.0.93.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
- dreamer4-0.0.93.dist-info/RECORD,,