dreamer4 0.0.70__tar.gz → 0.0.72__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


PKG-INFO:

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dreamer4
-Version: 0.0.70
+Version: 0.0.72
 Summary: Dreamer 4
 Project-URL: Homepage, https://pypi.org/project/dreamer4/
 Project-URL: Repository, https://github.com/lucidrains/dreamer4
dreamer4/dreamer4.py:

@@ -45,6 +45,7 @@ from assoc_scan import AssocScan
 # vc - video channels
 # vh, vw - video height and width
 # mtp - multi token prediction length
+# v - video viewpoints

 import einx
 from einx import add, multiply
@@ -154,6 +155,15 @@ def is_power_two(num):
 def is_empty(t):
     return t.numel() == 0

+def lens_to_mask(t, max_len = None):
+    if not exists(max_len):
+        max_len = t.amax().item()
+
+    device = t.device
+    seq = torch.arange(max_len, device = device)
+
+    return einx.less('j, i -> i j', seq, t)
+
 def log(t, eps = 1e-20):
     return t.clamp(min = eps).log()

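A note on the new lens_to_mask helper: it converts per-sample lengths into a boolean mask over sequence positions, which the loss changes further down use to ignore padded timesteps. A minimal standalone sketch of the same behavior, using plain torch broadcasting in place of the repo's einx.less:

import torch

def lens_to_mask_sketch(lens, max_len = None):
    # lens: (batch,) integer lengths -> (batch, max_len) boolean mask
    if max_len is None:
        max_len = int(lens.amax().item())

    seq = torch.arange(max_len, device = lens.device)
    return seq[None, :] < lens[:, None]  # True where the position index falls within the length

# lens_to_mask_sketch(torch.tensor([2, 4]), 4) gives
# [[True, True, False, False],
#  [True, True, True,  True ]]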
@@ -1735,6 +1745,7 @@ class DynamicsWorldModel(Module):
         num_latent_tokens = None,
         num_agents = 1,
         num_tasks = 0,
+        num_video_views = 1,
         dim_proprio = None,
         reward_encoder_kwargs: dict = dict(),
         depth = 4,
@@ -1800,7 +1811,7 @@
         )

         self.to_latent_pred = Sequential(
-            Reduce('b t n s d -> b t n d', 'mean'),
+            Reduce('b t v n s d -> b t v n d', 'mean'),
             RMSNorm(dim),
             LinearNoBias(dim, dim_latent)
         )
@@ -1810,17 +1821,27 @@
         latent_tokens_to_space = num_latent_tokens // num_spatial_tokens

         self.latents_to_spatial_tokens = Sequential(
-            Rearrange('b t n d -> b t (n d)'),
+            Rearrange('... n d -> ... (n d)'),
             Linear(num_latent_tokens * dim_latent, dim * num_spatial_tokens),
-            Rearrange('b t (s d) -> b t s d', s = num_spatial_tokens)
+            Rearrange('... (s d) -> ... s d', s = num_spatial_tokens)
         )

         self.to_latent_pred = Sequential(
             RMSNorm(dim),
             LinearNoBias(dim, dim_latent * latent_tokens_to_space),
-            Rearrange('b t s (n d) -> b t (s n) d', n = latent_tokens_to_space)
+            Rearrange('b t v s (n d) -> b t v (s n) d', n = latent_tokens_to_space)
         )

+        # number of video views, for robotics, which could have third person + wrist camera at least
+
+        assert num_video_views >= 1
+        self.video_has_multi_view = num_video_views > 1
+
+        self.num_video_views = num_video_views
+
+        if self.video_has_multi_view:
+            self.view_emb = nn.Parameter(torch.randn(num_video_views, dim) * 1e-2)
+
         # proprioception

         self.has_proprio = exists(dim_proprio)
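The hunk above only creates the view_emb parameter; where it is consumed is outside the lines shown here. As a hedged illustration of the usual pattern (hypothetical shapes and names, not the repo's exact code), a per-view embedding is typically broadcast-added onto the spatial tokens along the new view axis:

import torch
from einops import rearrange

b, t, v, s, d = 2, 3, 2, 4, 8                # batch, time, views, spatial tokens, dim
space_tokens = torch.randn(b, t, v, s, d)    # stand-in for the model's spatial tokens
view_emb = torch.randn(v, d) * 1e-2          # per-view learned embedding, as created above

# broadcast the view embedding over batch, time and spatial positions
space_tokens = space_tokens + rearrange(view_emb, 'v d -> 1 1 v 1 d')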
@@ -2318,7 +2339,7 @@
         # denoising
         # teacher forcing to start with

-        latents = empty((batch_size, 0, *latent_shape), device = self.device)
+        latents = empty((batch_size, 0, self.num_video_views, *latent_shape), device = self.device)

         past_latents_context_noise = latents.clone()

@@ -2354,7 +2375,7 @@

             curr_time_steps = latents.shape[1]

-            noised_latent = randn((batch_size, 1, *latent_shape), device = self.device)
+            noised_latent = randn((batch_size, 1, self.num_video_views, *latent_shape), device = self.device)

             noised_proprio = None

@@ -2365,12 +2386,12 @@
                 is_last_step = (step + 1) == num_steps

                 signal_levels = full((batch_size, 1), step * step_size, dtype = torch.long, device = self.device)
-
+
                 # noising past latent context

                 noised_context = latents.lerp(past_latents_context_noise, context_signal_noise) # the paragraph after eq (8)

-                noised_latent_with_context, pack_context_shape = pack((noised_context, noised_latent), 'b * n d')
+                noised_latent_with_context, pack_context_shape = pack((noised_context, noised_latent), 'b * v n d')

                 # handle proprio

@@ -2395,6 +2416,7 @@
                     proprio = noised_proprio_with_context,
                     time_kv_cache = time_kv_cache,
                     latent_is_noised = True,
+                    latent_has_view_dim = True,
                     return_pred_only = True,
                     return_intermediates = True,
                 )
@@ -2409,12 +2431,11 @@

                 # unpack pred

-                _, pred = unpack(pred, pack_context_shape, 'b * n d')
+                _, pred = unpack(pred, pack_context_shape, 'b * v n d')

                 if has_proprio:
                     _, pred_proprio = unpack(pred_proprio, pack_context_shape, 'b * d')

-
                 # derive flow, based on whether in x-space or not

                 def denoise_step(pred, noised, signal_levels):
@@ -2507,12 +2528,26 @@
         video = None

         if return_decoded_video:
+
+            latents_for_video = rearrange(latents, 'b t v n d -> b v t n d')
+            latents_for_video, unpack_view = pack_one(latents_for_video, '* t n d')
+
             video = self.video_tokenizer.decode(
-                latents,
+                latents_for_video,
                 height = image_height,
                 width = image_width
             )

+            video = unpack_view(video, '* t c vh vw')
+
+        # remove the lone view dimension
+
+        if not self.video_has_multi_view:
+            latents = rearrange(latents, 'b t 1 ... -> b t ...')
+
+            if exists(video):
+                video = rearrange(video, 'b 1 ... -> b ...')
+
         # only return video or latent if not requesting anything else, for first stage training

         if not has_at_least_one(return_rewards_per_frame, return_agent_actions, has_proprio):
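The decode path above folds the view axis into the batch so the existing single-view video_tokenizer.decode can be reused, then unfolds it afterwards (pack_one, as used here, returns the packed tensor plus an inverse function). A small sketch of the same fold/unfold pattern with einops.rearrange and dummy shapes:

import torch
from einops import rearrange

b, t, v, n, d = 2, 5, 2, 8, 16
latents = torch.randn(b, t, v, n, d)

# move views next to batch, then merge: (b, t, v, n, d) -> (b * v, t, n, d)
folded = rearrange(latents, 'b t v n d -> (b v) t n d')

# ... the per-view decoder runs on `folded`; a stand-in output is used here ...
decoded = torch.randn(b * v, t, 3, 64, 64)

# unfold the view axis back out of the batch: (b * v, t, c, h, w) -> (b, v, t, c, h, w)
video = rearrange(decoded, '(b v) t c h w -> b v t c h w', v = v)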
@@ -2553,8 +2588,9 @@
     def forward(
         self,
         *,
-        video = None, # (b c t vh vw)
-        latents = None, # (b t n d) | (b t d)
+        video = None, # (b v? c t vh vw)
+        latents = None, # (b t v? n d) | (b t v? d)
+        lens = None, # (b)
         signal_levels = None, # () | (b) | (b t)
         step_sizes = None, # () | (b)
         step_sizes_log2 = None, # () | (b)
@@ -2572,21 +2608,39 @@
         return_all_losses = False,
         return_intermediates = False,
         add_autoregressive_action_loss = False,
-        update_loss_ema = None
+        update_loss_ema = None,
+        latent_has_view_dim = False
     ):
         # handle video or latents

         assert exists(video) ^ exists(latents)

+        # standardize view dimension
+
+        if not self.video_has_multi_view:
+            if exists(video):
+                video = rearrange(video, 'b ... -> b 1 ...')
+
+            if exists(latents) and not latent_has_view_dim:
+                latents = rearrange(latents, 'b t ... -> b t 1 ...')
+
+        # if raw video passed in, tokenize
+
         if exists(video):
+            assert video.ndim == 6
+
+            video, unpack_views = pack_one(video, '* c t vh vw')
             assert exists(self.video_tokenizer), 'video_tokenizer must be passed in if training from raw video on dynamics model'

             latents = self.video_tokenizer.tokenize(video)
+            latents = unpack_views(latents, '* t n d')
+            latents = rearrange(latents, 'b v t n d -> b t v n d')

-        if latents.ndim == 3:
-            latents = rearrange(latents, 'b t d -> b t 1 d') # 1 latent edge case
+        if latents.ndim == 4:
+            latents = rearrange(latents, 'b t v d -> b t v 1 d') # 1 latent edge case

         assert latents.shape[-2:] == self.latent_shape
+        assert latents.shape[2] == self.num_video_views

         # variables

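The net effect of the standardization block is that a single-view model keeps the old call signature while internally always carrying a view axis, and a multi-view model expects the caller to pass that axis explicitly. A shape-only sketch mirroring the normalization above (a standalone function for illustration, not the repo's code):

import torch
from einops import rearrange

def standardize_video(video, num_video_views = 1):
    # single view: (b, c, t, vh, vw) -> (b, 1, c, t, vh, vw)
    # multi view: the caller already supplies (b, v, c, t, vh, vw)
    if num_video_views == 1:
        video = rearrange(video, 'b ... -> b 1 ...')

    assert video.ndim == 6
    return video

single_view = standardize_video(torch.randn(2, 3, 10, 64, 64))        # -> (2, 1, 3, 10, 64, 64)
multi_view = standardize_video(torch.randn(2, 2, 3, 10, 64, 64), 2)   # -> (2, 2, 3, 10, 64, 64)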
@@ -2769,6 +2823,7 @@
         # main function, needs to be defined as such for shortcut training - additional calls for consistency loss

         def get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, action_tokens, reward_tokens, agent_tokens, return_agent_tokens = False, return_time_kv_cache = False):
+
             # latents to spatial tokens

             space_tokens = self.latents_to_spatial_tokens(noised_latents)
@@ -2969,7 +3024,19 @@

         flow_losses = flow_losses * loss_weight

-        flow_loss = flow_losses.mean()
+        # handle variable lengths if needed
+
+        is_var_len = exists(lens)
+
+        if is_var_len:
+
+            loss_mask = lens_to_mask(lens, time)
+            loss_mask_without_last = loss_mask[:, :-1]
+
+            flow_loss = flow_losses[loss_mask].mean()
+
+        else:
+            flow_loss = flow_losses.mean()

         # now take care of the agent token losses

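The variable-length branch above replaces a global mean with a masked mean over valid timesteps. A self-contained sketch of that masking with made-up shapes:

import torch

b, t = 2, 4
flow_losses = torch.randn(b, t).abs()              # stand-in for per-timestep flow losses
lens = torch.tensor([2, 4])                        # valid lengths per sample

mask = torch.arange(t)[None, :] < lens[:, None]    # same result as lens_to_mask(lens, t)

masked_flow_loss = flow_losses[mask].mean()        # averages only the valid positions
plain_flow_loss = flow_losses.mean()               # what the fixed-length path computes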
@@ -2992,7 +3059,10 @@

             reward_losses = reward_losses.masked_fill(reward_loss_mask, 0.)

-            reward_loss = reduce(reward_losses, '... mtp -> mtp', 'mean') # they sum across the prediction steps (mtp dimension) - eq(9)
+            if is_var_len:
+                reward_loss = reward_losses[loss_mask_without_last].mean(dim = 0)
+            else:
+                reward_loss = reduce(reward_losses, '... mtp -> mtp', 'mean') # they sum across the prediction steps (mtp dimension) - eq(9)

         # maybe autoregressive action loss

@@ -3035,12 +3105,20 @@
         if exists(discrete_log_probs):
             discrete_log_probs = discrete_log_probs.masked_fill(~discrete_mask[..., None], 0.)

-            discrete_action_loss = reduce(-discrete_log_probs, 'mtp b t na -> mtp', 'mean')
+            if is_var_len:
+                discrete_action_losses = rearrange(-discrete_log_probs, 'mtp b t na -> b t na mtp')
+                discrete_action_loss = reduce(discrete_action_losses[loss_mask_without_last], '... mtp -> mtp', 'mean')
+            else:
+                discrete_action_loss = reduce(-discrete_log_probs, 'mtp b t na -> mtp', 'mean')

         if exists(continuous_log_probs):
             continuous_log_probs = continuous_log_probs.masked_fill(~continuous_mask[..., None], 0.)

-            continuous_action_loss = reduce(-continuous_log_probs, 'mtp b t na -> mtp', 'mean')
+            if is_var_len:
+                continuous_action_losses = rearrange(-continuous_log_probs, 'mtp b t na -> b t na mtp')
+                continuous_action_loss = reduce(continuous_action_losses[loss_mask_without_last], '... mtp -> mtp', 'mean')
+            else:
+                continuous_action_loss = reduce(-continuous_log_probs, 'mtp b t na -> mtp', 'mean')

         # handle loss normalization

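The action-loss branches follow the same idea, but the log-prob tensors are laid out with the multi-token-prediction dimension first, so they are rearranged before the (b, t) mask can index them (the mask also drops its last timestep, presumably because these heads predict the next step). A toy sketch of that rearrange-then-mask pattern:

import torch
from einops import rearrange, reduce

mtp, b, t, na = 2, 2, 4, 3
neg_log_probs = torch.randn(mtp, b, t, na).abs()        # stand-in for -log_probs
mask = torch.tensor([[True, True, False, False],
                     [True, True, True,  True]])        # (b, t) validity mask

# put (b, t) in front so the 2-D mask can index them, keep mtp as the trailing dim
losses = rearrange(neg_log_probs, 'mtp b t na -> b t na mtp')
masked = losses[mask]                                   # (num_valid, na, mtp)

action_loss = reduce(masked, '... mtp -> mtp', 'mean')  # one loss per prediction step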
pyproject.toml:

@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.70"
+version = "0.0.72"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
tests:

@@ -16,6 +16,7 @@ def exists(v):
 @param('num_residual_streams', (1, 4))
 @param('add_reward_embed_to_agent_token', (False, True))
 @param('use_time_kv_cache', (False, True))
+@param('var_len', (False, True))
 def test_e2e(
     pred_orig_latent,
     grouped_query_attn,
@@ -27,7 +28,8 @@ def test_e2e(
     condition_on_actions,
     num_residual_streams,
     add_reward_embed_to_agent_token,
-    use_time_kv_cache
+    use_time_kv_cache,
+    var_len
 ):
     from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel

@@ -95,8 +97,13 @@ def test_e2e(
     if condition_on_actions:
         actions = torch.randint(0, 4, (2, 3, 1))

+    lens = None
+    if var_len:
+        lens = torch.randint(1, 4, (2,))
+
     flow_loss = dynamics(
         **dynamics_input,
+        lens = lens,
         tasks = tasks,
         signal_levels = signal_levels,
         step_sizes_log2 = step_sizes_log2,
@@ -668,7 +675,10 @@ def test_online_rl(

     trainer(mock_env, num_episodes = 2, env_is_vectorized = vectorized)

-def test_proprioception():
+@param('num_video_views', (1, 2))
+def test_proprioception(
+    num_video_views
+):
     from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel

     tokenizer = VideoTokenizer(
@@ -693,11 +703,17 @@ def test_proprioception():
         dim_latent = 32,
         dim_proprio = 21,
         num_tasks = 4,
+        num_video_views = num_video_views,
         num_discrete_actions = 4,
         num_residual_streams = 1
     )

-    video = torch.randn(2, 3, 10, 256, 256)
+    if num_video_views > 1:
+        video_shape = (2, num_video_views, 3, 10, 256, 256)
+    else:
+        video_shape = (2, 3, 10, 256, 256)
+
+    video = torch.randn(*video_shape)
     rewards = torch.randn(2, 10)
     proprio = torch.randn(2, 10, 21)
     discrete_actions = torch.randint(0, 4, (2, 10, 1))
@@ -714,8 +730,10 @@ def test_proprioception():
     loss.backward()

     generations = dynamics.generate(
-        4,
+        10,
         batch_size = 2,
+        return_decoded_video = True
     )

     assert exists(generations.proprio)
+    assert generations.video.shape == video_shape