dreamer4 0.0.70__tar.gz → 0.0.71__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dreamer4
-Version: 0.0.70
+Version: 0.0.71
 Summary: Dreamer 4
 Project-URL: Homepage, https://pypi.org/project/dreamer4/
 Project-URL: Repository, https://github.com/lucidrains/dreamer4
@@ -45,6 +45,7 @@ from assoc_scan import AssocScan
 # vc - video channels
 # vh, vw - video height and width
 # mtp - multi token prediction length
+# v - video viewpoints

 import einx
 from einx import add, multiply
@@ -1735,6 +1736,7 @@ class DynamicsWorldModel(Module):
         num_latent_tokens = None,
         num_agents = 1,
         num_tasks = 0,
+        num_video_views = 1,
         dim_proprio = None,
         reward_encoder_kwargs: dict = dict(),
         depth = 4,
@@ -1800,7 +1802,7 @@ class DynamicsWorldModel(Module):
         )

         self.to_latent_pred = Sequential(
-            Reduce('b t n s d -> b t n d', 'mean'),
+            Reduce('b t v n s d -> b t v n d', 'mean'),
             RMSNorm(dim),
             LinearNoBias(dim, dim_latent)
         )
@@ -1810,17 +1812,27 @@
         latent_tokens_to_space = num_latent_tokens // num_spatial_tokens

         self.latents_to_spatial_tokens = Sequential(
-            Rearrange('b t n d -> b t (n d)'),
+            Rearrange('... n d -> ... (n d)'),
             Linear(num_latent_tokens * dim_latent, dim * num_spatial_tokens),
-            Rearrange('b t (s d) -> b t s d', s = num_spatial_tokens)
+            Rearrange('... (s d) -> ... s d', s = num_spatial_tokens)
         )

         self.to_latent_pred = Sequential(
             RMSNorm(dim),
             LinearNoBias(dim, dim_latent * latent_tokens_to_space),
-            Rearrange('b t s (n d) -> b t (s n) d', n = latent_tokens_to_space)
+            Rearrange('b t v s (n d) -> b t v (s n) d', n = latent_tokens_to_space)
         )

+        # number of video views, for robotics, which could have third person + wrist camera at least
+
+        assert num_video_views >= 1
+        self.video_has_multi_view = num_video_views > 1
+
+        self.num_video_views = num_video_views
+
+        if self.video_has_multi_view:
+            self.view_emb = nn.Parameter(torch.randn(num_video_views, dim) * 1e-2)
+
         # proprioception

         self.has_proprio = exists(dim_proprio)
@@ -2318,7 +2330,7 @@ class DynamicsWorldModel(Module):
         # denoising
         # teacher forcing to start with

-        latents = empty((batch_size, 0, *latent_shape), device = self.device)
+        latents = empty((batch_size, 0, self.num_video_views, *latent_shape), device = self.device)

         past_latents_context_noise = latents.clone()

@@ -2354,7 +2366,7 @@

         curr_time_steps = latents.shape[1]

-        noised_latent = randn((batch_size, 1, *latent_shape), device = self.device)
+        noised_latent = randn((batch_size, 1, self.num_video_views, *latent_shape), device = self.device)

         noised_proprio = None

@@ -2365,12 +2377,12 @@
         is_last_step = (step + 1) == num_steps

         signal_levels = full((batch_size, 1), step * step_size, dtype = torch.long, device = self.device)
-
+
         # noising past latent context

         noised_context = latents.lerp(past_latents_context_noise, context_signal_noise) # the paragraph after eq (8)

-        noised_latent_with_context, pack_context_shape = pack((noised_context, noised_latent), 'b * n d')
+        noised_latent_with_context, pack_context_shape = pack((noised_context, noised_latent), 'b * v n d')

         # handle proprio

@@ -2395,6 +2407,7 @@
             proprio = noised_proprio_with_context,
             time_kv_cache = time_kv_cache,
             latent_is_noised = True,
+            latent_has_view_dim = True,
             return_pred_only = True,
             return_intermediates = True,
         )
@@ -2409,12 +2422,11 @@

         # unpack pred

-        _, pred = unpack(pred, pack_context_shape, 'b * n d')
+        _, pred = unpack(pred, pack_context_shape, 'b * v n d')

         if has_proprio:
             _, pred_proprio = unpack(pred_proprio, pack_context_shape, 'b * d')

-
         # derive flow, based on whether in x-space or not

         def denoise_step(pred, noised, signal_levels):
@@ -2507,12 +2519,26 @@
         video = None

         if return_decoded_video:
+
+            latents_for_video = rearrange(latents, 'b t v n d -> b v t n d')
+            latents_for_video, unpack_view = pack_one(latents_for_video, '* t n d')
+
             video = self.video_tokenizer.decode(
-                latents,
+                latents_for_video,
                 height = image_height,
                 width = image_width
             )

+            video = unpack_view(video, '* t c vh vw')
+
+        # remove the lone view dimension
+
+        if not self.video_has_multi_view:
+            latents = rearrange(latents, 'b t 1 ... -> b t ...')
+
+            if exists(video):
+                video = rearrange(video, 'b 1 ... -> b ...')
+
         # only return video or latent if not requesting anything else, for first stage training

         if not has_at_least_one(return_rewards_per_frame, return_agent_actions, has_proprio):
@@ -2553,8 +2579,8 @@
     def forward(
         self,
         *,
-        video = None, # (b c t vh vw)
-        latents = None, # (b t n d) | (b t d)
+        video = None, # (b v? c t vh vw)
+        latents = None, # (b t v? n d) | (b t v? d)
         signal_levels = None, # () | (b) | (b t)
         step_sizes = None, # () | (b)
         step_sizes_log2 = None, # () | (b)
@@ -2572,21 +2598,39 @@
         return_all_losses = False,
         return_intermediates = False,
         add_autoregressive_action_loss = False,
-        update_loss_ema = None
+        update_loss_ema = None,
+        latent_has_view_dim = False
     ):
         # handle video or latents

         assert exists(video) ^ exists(latents)

+        # standardize view dimension
+
+        if not self.video_has_multi_view:
+            if exists(video):
+                video = rearrange(video, 'b ... -> b 1 ...')
+
+            if exists(latents) and not latent_has_view_dim:
+                latents = rearrange(latents, 'b t ... -> b t 1 ...')
+
+        # if raw video passed in, tokenize
+
         if exists(video):
+            assert video.ndim == 6
+
+            video, unpack_views = pack_one(video, '* c t vh vw')
             assert exists(self.video_tokenizer), 'video_tokenizer must be passed in if training from raw video on dynamics model'

             latents = self.video_tokenizer.tokenize(video)
+            latents = unpack_views(latents, '* t n d')
+            latents = rearrange(latents, 'b v t n d -> b t v n d')

-        if latents.ndim == 3:
-            latents = rearrange(latents, 'b t d -> b t 1 d') # 1 latent edge case
+        if latents.ndim == 4:
+            latents = rearrange(latents, 'b t v d -> b t v 1 d') # 1 latent edge case

         assert latents.shape[-2:] == self.latent_shape
+        assert latents.shape[2] == self.num_video_views

         # variables

@@ -2769,6 +2813,7 @@
         # main function, needs to be defined as such for shortcut training - additional calls for consistency loss

         def get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, action_tokens, reward_tokens, agent_tokens, return_agent_tokens = False, return_time_kv_cache = False):
+
             # latents to spatial tokens

             space_tokens = self.latents_to_spatial_tokens(noised_latents)
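
For orientation, the source hunks above thread a new view dimension through the dynamics model: single-view inputs gain a lone view axis, views are folded into the batch before the video tokenizer runs, and the resulting latents carry the shape (b t v n d). Below is a minimal shape-level sketch of that handling. It is not code from the package: it uses plain einops pack/unpack and rearrange instead of the repo's internal pack_one helper, and all sizes are made up for illustration.

import torch
from einops import rearrange, pack, unpack

b, v, c, t, vh, vw = 2, 2, 3, 10, 64, 64   # hypothetical batch, views, channels, time, height, width
n, d = 8, 32                               # hypothetical latent tokens per frame and latent dim

# a single-view video gains a lone view axis so both cases share one code path
single_view = torch.randn(b, c, t, vh, vw)                 # (b c t vh vw)
single_view = rearrange(single_view, 'b ... -> b 1 ...')   # (b 1 c t vh vw)

# a multi-view video already carries the view axis
video = torch.randn(b, v, c, t, vh, vw)                    # (b v c t vh vw)

# fold views into the batch so a single-view tokenizer can run unchanged
video_flat, packed_shape = pack([video], '* c t vh vw')    # ((b v) c t vh vw)

# stand-in for the tokenizer output, shaped (batch * views, time, tokens, dim)
latents = torch.randn(video_flat.shape[0], t, n, d)

# unfold the views and move them next to the time axis, giving (b t v n d)
[latents] = unpack(latents, packed_shape, '* t n d')
latents = rearrange(latents, 'b v t n d -> b t v n d')
assert latents.shape == (b, t, v, n, d)
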
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.70"
+version = "0.0.71"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -668,7 +668,10 @@ def test_online_rl(

     trainer(mock_env, num_episodes = 2, env_is_vectorized = vectorized)

-def test_proprioception():
+@param('num_video_views', (1, 2))
+def test_proprioception(
+    num_video_views
+):
     from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel

     tokenizer = VideoTokenizer(
@@ -693,11 +696,16 @@ def test_proprioception():
         dim_latent = 32,
         dim_proprio = 21,
         num_tasks = 4,
+        num_video_views = num_video_views,
         num_discrete_actions = 4,
         num_residual_streams = 1
     )

-    video = torch.randn(2, 3, 10, 256, 256)
+    if num_video_views > 1:
+        video = torch.randn(2, num_video_views, 3, 10, 256, 256)
+    else:
+        video = torch.randn(2, 3, 10, 256, 256)
+
     rewards = torch.randn(2, 10)
     proprio = torch.randn(2, 10, 21)
     discrete_actions = torch.randint(0, 4, (2, 10, 1))