dreamer4 0.0.71__tar.gz → 0.0.73__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of dreamer4 might be problematic.

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dreamer4
-Version: 0.0.71
+Version: 0.0.73
 Summary: Dreamer 4
 Project-URL: Homepage, https://pypi.org/project/dreamer4/
 Project-URL: Repository, https://github.com/lucidrains/dreamer4
@@ -82,6 +82,7 @@ class Experience:
     log_probs: tuple[Tensor, Tensor] | None = None
     values: Tensor | None = None
     step_size: int | None = None
+    lens: Tensor | None = None,
     agent_index: int = 0
     is_from_world_model: bool = True
 
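The new `lens` field lets an Experience carry a per-sample count of valid timesteps alongside its padded tensors. Note that, as released, the added line ends with a trailing comma, so its default evaluates to the one-element tuple `(None,)` rather than `None`. A minimal sketch of the idea, with every field name other than `lens` kept purely illustrative:

    from dataclasses import dataclass
    from torch import Tensor

    @dataclass
    class ExperienceSketch:
        latents: Tensor                  # (b, t, n, d) padded rollout latents
        lens: Tensor | None = None       # (b,) number of valid timesteps per sample
        is_from_world_model: bool = True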
@@ -155,6 +156,15 @@ def is_power_two(num):
 def is_empty(t):
     return t.numel() == 0
 
+def lens_to_mask(t, max_len = None):
+    if not exists(max_len):
+        max_len = t.amax().item()
+
+    device = t.device
+    seq = torch.arange(max_len, device = device)
+
+    return einx.less('j, i -> i j', seq, t)
+
 def log(t, eps = 1e-20):
     return t.clamp(min = eps).log()
 
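The new `lens_to_mask` helper converts per-sample lengths into a boolean mask over timesteps, True wherever the timestep index falls below that sample's length. A self-contained sketch of the same comparison in plain PyTorch (the package routes it through `einx.less`), with a tiny worked example:

    import torch

    def lens_to_mask_sketch(lens, max_len = None):
        # lens: (b,) integer lengths -> (b, max_len) boolean mask
        if max_len is None:
            max_len = int(lens.amax().item())

        seq = torch.arange(max_len, device = lens.device)    # (max_len,)
        return seq[None, :] < lens[:, None]                   # broadcasted 'index < length'

    lens = torch.tensor([2, 4])
    print(lens_to_mask_sketch(lens, max_len = 4))
    # tensor([[ True,  True, False, False],
    #         [ True,  True,  True,  True]])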
@@ -2551,12 +2561,16 @@ class DynamicsWorldModel(Module):
 
         # returning agent actions, rewards, and log probs + values for policy optimization
 
+        batch, device = latents.shape[0], latents.device
+        experience_lens = torch.full((batch,), time_steps, device = device)
+
         gen = Experience(
             latents = latents,
             video = video,
             proprio = proprio if has_proprio else None,
             step_size = step_size,
             agent_index = agent_index,
+            lens = experience_lens,
             is_from_world_model = True
         )
 
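Rollouts produced by the world model always use every timestep, so `generate` fills the new `lens` field with the full rollout length for each batch element. A small illustration of that bookkeeping and of why it leaves the loss unchanged (names here are only for the example):

    import torch

    batch, time_steps = 2, 6

    # generated trajectories have no padding, so each sample gets the full length
    experience_lens = torch.full((batch,), time_steps)
    print(experience_lens)        # tensor([6, 6])

    # the mask derived from such lens (see lens_to_mask above) is all True,
    # so the variable-length loss path reduces to the ordinary mean
    mask = torch.arange(time_steps)[None, :] < experience_lens[:, None]
    print(mask.all())             # tensor(True)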
@@ -2581,6 +2595,7 @@ class DynamicsWorldModel(Module):
         *,
         video = None,           # (b v? c t vh vw)
         latents = None,         # (b t v? n d) | (b t v? d)
+        lens = None,            # (b)
         signal_levels = None,   # () | (b) | (b t)
         step_sizes = None,      # () | (b)
         step_sizes_log2 = None, # () | (b)
@@ -3014,7 +3029,19 @@ class DynamicsWorldModel(Module):
 
         flow_losses = flow_losses * loss_weight
 
-        flow_loss = flow_losses.mean()
+        # handle variable lengths if needed
+
+        is_var_len = exists(lens)
+
+        if is_var_len:
+
+            loss_mask = lens_to_mask(lens, time)
+            loss_mask_without_last = loss_mask[:, :-1]
+
+            flow_loss = flow_losses[loss_mask].mean()
+
+        else:
+            flow_loss = flow_losses.mean()
 
         # now take care of the agent token losses
 
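With `lens` supplied, the flow loss is averaged only over valid timesteps instead of over the whole padded tensor. A self-contained sketch of that masked reduction, with made-up loss values so the difference from a plain mean is visible:

    import torch

    b, t = 2, 4
    lens = torch.tensor([2, 4])                # sample 0: 2 valid steps, sample 1: 4

    flow_losses = torch.tensor([[1., 1., 9., 9.],      # the 9s sit in padded positions
                                [1., 1., 1., 1.]])

    loss_mask = torch.arange(t)[None, :] < lens[:, None]   # (b, t) validity mask

    masked_loss = flow_losses[loss_mask].mean()   # 1.0 -> averages only the 6 valid entries
    plain_loss  = flow_losses.mean()              # 3.0 -> padded entries leak into the mean
    print(masked_loss, plain_loss)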
@@ -3037,7 +3064,10 @@ class DynamicsWorldModel(Module):
 
         reward_losses = reward_losses.masked_fill(reward_loss_mask, 0.)
 
-        reward_loss = reduce(reward_losses, '... mtp -> mtp', 'mean') # they sum across the prediction steps (mtp dimension) - eq(9)
+        if is_var_len:
+            reward_loss = reward_losses[loss_mask_without_last].mean(dim = 0)
+        else:
+            reward_loss = reduce(reward_losses, '... mtp -> mtp', 'mean') # they sum across the prediction steps (mtp dimension) - eq(9)
 
         # maybe autoregressive action loss
 
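For the reward head the mask additionally drops the last timestep (`loss_mask_without_last`), since rewards are predicted for upcoming steps. Indexing a loss tensor whose leading dims match that (b, t-1) mask keeps the multi-token-prediction axis intact, and `mean(dim = 0)` then yields one loss per prediction step. A shape check under those assumed shapes:

    import torch

    b, t, mtp = 2, 5, 3
    reward_losses = torch.rand(b, t - 1, mtp)               # assumed (b, t-1, mtp) layout

    lens = torch.tensor([3, 5])
    loss_mask = torch.arange(t)[None, :] < lens[:, None]    # (b, t)
    loss_mask_without_last = loss_mask[:, :-1]              # (b, t-1)

    reward_loss = reward_losses[loss_mask_without_last].mean(dim = 0)
    print(reward_loss.shape)    # torch.Size([3]): one loss per mtp prediction step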
@@ -3080,12 +3110,20 @@ class DynamicsWorldModel(Module):
         if exists(discrete_log_probs):
             discrete_log_probs = discrete_log_probs.masked_fill(~discrete_mask[..., None], 0.)
 
-            discrete_action_loss = reduce(-discrete_log_probs, 'mtp b t na -> mtp', 'mean')
+            if is_var_len:
+                discrete_action_losses = rearrange(-discrete_log_probs, 'mtp b t na -> b t na mtp')
+                discrete_action_loss = reduce(discrete_action_losses[loss_mask_without_last], '... mtp -> mtp', 'mean')
+            else:
+                discrete_action_loss = reduce(-discrete_log_probs, 'mtp b t na -> mtp', 'mean')
 
         if exists(continuous_log_probs):
             continuous_log_probs = continuous_log_probs.masked_fill(~continuous_mask[..., None], 0.)
 
-            continuous_action_loss = reduce(-continuous_log_probs, 'mtp b t na -> mtp', 'mean')
+            if is_var_len:
+                continuous_action_losses = rearrange(-continuous_log_probs, 'mtp b t na -> b t na mtp')
+                continuous_action_loss = reduce(continuous_action_losses[loss_mask_without_last], '... mtp -> mtp', 'mean')
+            else:
+                continuous_action_loss = reduce(-continuous_log_probs, 'mtp b t na -> mtp', 'mean')
 
         # handle loss normalization
 
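The action losses follow the same recipe: the negative log probs, stored with the multi-token-prediction axis first ('mtp b t na'), are rearranged so batch and time lead, indexed with the validity mask, and reduced to one loss per prediction step. A sketch with einops, assuming the time axis of the log probs already lines up with the (b, t-1) mask:

    import torch
    from einops import rearrange, reduce

    mtp, b, time, na = 3, 2, 5, 1
    t_pred = time - 1                                          # assumed to match the mask width

    discrete_log_probs = torch.rand(mtp, b, t_pred, na).log()  # stand-in log probs, mtp axis first

    lens = torch.tensor([3, 5])
    loss_mask_without_last = (torch.arange(time)[None, :] < lens[:, None])[:, :-1]   # (b, time - 1)

    losses = rearrange(-discrete_log_probs, 'mtp b t na -> b t na mtp')
    action_loss = reduce(losses[loss_mask_without_last], '... mtp -> mtp', 'mean')
    print(action_loss.shape)                                   # torch.Size([3])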
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.71"
+version = "0.0.73"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -16,6 +16,7 @@ def exists(v):
 @param('num_residual_streams', (1, 4))
 @param('add_reward_embed_to_agent_token', (False, True))
 @param('use_time_kv_cache', (False, True))
+@param('var_len', (False, True))
 def test_e2e(
     pred_orig_latent,
     grouped_query_attn,
@@ -27,7 +28,8 @@ def test_e2e(
     condition_on_actions,
     num_residual_streams,
     add_reward_embed_to_agent_token,
-    use_time_kv_cache
+    use_time_kv_cache,
+    var_len
 ):
     from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
 
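The new `var_len` flag is just another parametrized axis of `test_e2e`. For readers unfamiliar with the shorthand, here is a hypothetical definition of the `param` decorator, assuming it merely wraps `pytest.mark.parametrize` (its real definition is not part of this diff):

    import pytest

    def param(name, values):
        # hypothetical helper: parametrize a single argument over the given values
        return pytest.mark.parametrize(name, values)

    @param('var_len', (False, True))
    def test_example(var_len):
        assert var_len in (False, True)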
@@ -95,8 +97,13 @@ def test_e2e(
     if condition_on_actions:
         actions = torch.randint(0, 4, (2, 3, 1))
 
+    lens = None
+    if var_len:
+        lens = torch.randint(1, 4, (2,))
+
     flow_loss = dynamics(
         **dynamics_input,
+        lens = lens,
         tasks = tasks,
         signal_levels = signal_levels,
         step_sizes_log2 = step_sizes_log2,
@@ -702,10 +709,11 @@ def test_proprioception(
     )
 
     if num_video_views > 1:
-        video = torch.randn(2, num_video_views, 3, 10, 256, 256)
+        video_shape = (2, num_video_views, 3, 10, 256, 256)
     else:
-        video = torch.randn(2, 3, 10, 256, 256)
+        video_shape = (2, 3, 10, 256, 256)
 
+    video = torch.randn(*video_shape)
     rewards = torch.randn(2, 10)
     proprio = torch.randn(2, 10, 21)
     discrete_actions = torch.randint(0, 4, (2, 10, 1))
@@ -722,8 +730,10 @@ def test_proprioception(
     loss.backward()
 
     generations = dynamics.generate(
-        4,
+        10,
         batch_size = 2,
+        return_decoded_video = True
     )
 
     assert exists(generations.proprio)
+    assert generations.video.shape == video_shape
5 files without changes