dreamer4 0.0.94.tar.gz → 0.0.95.tar.gz

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.

Potentially problematic release: this version of dreamer4 has been flagged for further review.

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dreamer4
-Version: 0.0.94
+Version: 0.0.95
 Summary: Dreamer 4
 Project-URL: Homepage, https://pypi.org/project/dreamer4/
 Project-URL: Repository, https://github.com/lucidrains/dreamer4

@@ -72,20 +72,22 @@ TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips'))
 
 WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'rewards', 'discrete_actions', 'continuous_actions'))
 
+MaybeTensor = Tensor | None
+
 @dataclass
 class Experience:
     latents: Tensor
-    video: Tensor | None = None
-    proprio: Tensor | None = None
-    agent_embed: Tensor | None = None
+    video: MaybeTensor = None
+    proprio: MaybeTensor = None
+    agent_embed: MaybeTensor = None
     rewards: Tensor | None = None
-    actions: tuple[Tensor, Tensor] | None = None
-    log_probs: tuple[Tensor, Tensor] | None = None
-    old_action_unembeds: tuple[Tensor, Tensor] | None = None
-    values: Tensor | None = None
+    actions: tuple[MaybeTensor, MaybeTensor] | None = None
+    log_probs: tuple[MaybeTensor, MaybeTensor] | None = None
+    old_action_unembeds: tuple[MaybeTensor, MaybeTensor] | None = None
+    values: MaybeTensor = None
     step_size: int | None = None
-    lens: Tensor | None = None
-    is_truncated: Tensor | None = None
+    lens: MaybeTensor = None
+    is_truncated: MaybeTensor = None
     agent_index: int = 0
     is_from_world_model: bool = True
 
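For context: the `MaybeTensor` change above is a pure type-level refactor, aliasing the repeated `Tensor | None` unions. A minimal sketch of the pattern (assuming Python 3.10+ and torch installed; only a few of the `Experience` fields from the diff are reproduced, and the `randn(2, 3, 512)` shape is illustrative only):

    # minimal sketch of the 0.0.95 refactor; not the full Experience class
    from dataclasses import dataclass
    from torch import Tensor, randn

    MaybeTensor = Tensor | None  # alias introduced in this release

    @dataclass
    class Experience:
        latents: Tensor
        video: MaybeTensor = None
        actions: tuple[MaybeTensor, MaybeTensor] | None = None

    # only latents is required; the optional fields default to None
    exp = Experience(latents = randn(2, 3, 512))
    assert exp.video is None and exp.actions is None
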
@@ -850,9 +852,10 @@ class ActionEmbedder(Module):
 
     def kl_div(
         self,
-        src: tuple[Tensor | None, Tensor | None],
-        tgt: tuple[Tensor | None, Tensor | None]
-    ) -> tuple[Tensor | None, Tensor | None]:
+        src: tuple[MaybeTensor, MaybeTensor],
+        tgt: tuple[MaybeTensor, MaybeTensor],
+        reduce_across_num_actions = True
+    ) -> tuple[MaybeTensor, MaybeTensor]:
 
         src_discrete, src_continuous = src
         tgt_discrete, tgt_continuous = tgt

@@ -894,6 +897,15 @@ class ActionEmbedder(Module):
 
         continuous_kl_div = kl.kl_divergence(src_normal, tgt_normal)
 
+        # maybe reduce
+
+        if reduce_across_num_actions:
+            if exists(discrete_kl_div):
+                discrete_kl_div = discrete_kl_div.sum(dim = -1)
+
+            if exists(continuous_kl_div):
+                continuous_kl_div = continuous_kl_div.sum(dim = -1)
+
         return discrete_kl_div, continuous_kl_div
 
     def forward(
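
This is the behavioral change in 0.0.95: `kl_div` now sums the per-action KL divergences over the trailing num-actions dimension by default (`reduce_across_num_actions = True`). A standalone sketch of the continuous half of that reduction, using `torch.distributions` directly rather than the `ActionEmbedder` internals (the `(2, 3, 2)` batch/time/num-actions shape mirrors the test change below):

    import torch
    from torch.distributions import Normal, kl

    # per-action continuous KL with shape (batch = 2, time = 3, num_actions = 2)
    src_normal = Normal(torch.zeros(2, 3, 2), torch.ones(2, 3, 2))
    tgt_normal = Normal(torch.ones(2, 3, 2), torch.ones(2, 3, 2))

    continuous_kl_div = kl.kl_divergence(src_normal, tgt_normal)  # (2, 3, 2)

    # the new default: sum across the num-actions dimension
    assert continuous_kl_div.sum(dim = -1).shape == (2, 3)
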
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.94"
+version = "0.0.95"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

@@ -352,8 +352,8 @@ def test_action_embedder():
 
     discrete_kl_div, continuous_kl_div = embedder.kl_div((discrete_logits, continuous_mean_log_var), (discrete_logits_tgt, continuous_mean_log_var_tgt))
 
-    assert discrete_kl_div.shape == (2, 3, 2)
-    assert continuous_kl_div.shape == (2, 3, 2)
+    assert discrete_kl_div.shape == (2, 3)
+    assert continuous_kl_div.shape == (2, 3)
 
     # return discrete split by number of actions
 
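The updated assertions reflect the new default: the returned KL tensors are reduced from per-action `(2, 3, 2)` to per-timestep `(2, 3)`. Per the signature above, passing `reduce_across_num_actions = False` skips the sum and preserves the old per-action shape. The discrete half of the reduction, sketched the same way (the 5-way logits are illustrative; only the `(2, 3, 2)` shape comes from the test):

    import torch
    from torch.distributions import Categorical, kl

    # per-action discrete KL: batch = 2, time = 3, num_actions = 2, 5 logits each
    src = Categorical(logits = torch.randn(2, 3, 2, 5))
    tgt = Categorical(logits = torch.randn(2, 3, 2, 5))

    discrete_kl_div = kl.kl_divergence(src, tgt)          # (2, 3, 2), the old assert
    assert discrete_kl_div.sum(dim = -1).shape == (2, 3)  # the new assert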