dreamer4 0.0.99__tar.gz → 0.1.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dreamer4
- Version: 0.0.99
+ Version: 0.1.5
 Summary: Dreamer 4
 Project-URL: Homepage, https://pypi.org/project/dreamer4/
 Project-URL: Repository, https://github.com/lucidrains/dreamer4
@@ -53,11 +53,100 @@ Description-Content-Type: text/markdown

 <img src="./dreamer4-fig2.png" width="400px"></img>

- ## Dreamer 4 (wip)
+ ## Dreamer 4

 Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v1) for his [Dreamer](https://danijar.com/project/dreamer4/) line of work

- [Temporary Discord](https://discord.gg/MkACrrkrYR)
+ [Discord channel](https://discord.gg/ab4BEk3W) for collaborating with other researchers interested in this work
+
+ ## Appreciation
+
+ - [@dirkmcpherson](https://github.com/dirkmcpherson) for fixes to typo errors and unpassed arguments!
+
+ ## Install
+
+ ```bash
+ $ pip install dreamer4
+ ```
+
+ ## Usage
+
+ ```python
+ import torch
+ from dreamer4 import VideoTokenizer, DynamicsWorldModel
+
+ # video tokenizer, learned through MAE + lpips
+
+ tokenizer = VideoTokenizer(
+     dim = 512,
+     dim_latent = 32,
+     patch_size = 32,
+     image_height = 256,
+     image_width = 256
+ )
+
+ video = torch.randn(2, 3, 10, 256, 256)
+
+ # learn the tokenizer
+
+ loss = tokenizer(video)
+ loss.backward()
+
+ # dynamics world model
+
+ world_model = DynamicsWorldModel(
+     dim = 512,
+     dim_latent = 32,
+     video_tokenizer = tokenizer,
+     num_discrete_actions = 4,
+     num_residual_streams = 1
+ )
+
+ # state, action, rewards
+
+ video = torch.randn(2, 3, 10, 256, 256)
+ discrete_actions = torch.randint(0, 4, (2, 10, 1))
+ rewards = torch.randn(2, 10)
+
+ # learn dynamics / behavior cloned model
+
+ loss = world_model(
+     video = video,
+     rewards = rewards,
+     discrete_actions = discrete_actions
+ )
+
+ loss.backward()
+
+ # do the above with much data
+
+ # then generate dreams
+
+ dreams = world_model.generate(
+     10,
+     batch_size = 2,
+     return_decoded_video = True,
+     return_for_policy_optimization = True
+ )
+
+ # learn from the dreams
+
+ actor_loss, critic_loss = world_model.learn_from_experience(dreams)
+
+ (actor_loss + critic_loss).backward()
+
+ # learn from environment
+
+ from dreamer4.mocks import MockEnv
+
+ mock_env = MockEnv((256, 256), vectorized = True, num_envs = 4)
+
+ experience = world_model.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = True)
+
+ actor_loss, critic_loss = world_model.learn_from_experience(experience)
+
+ (actor_loss + critic_loss).backward()
+ ```

 ## Citation

@@ -72,3 +161,5 @@ Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v
 url = {https://arxiv.org/abs/2509.24527},
 }
 ```
+
+ *the conquest of nature is to be achieved through number and measure - angels to Descartes in a dream*
@@ -0,0 +1,112 @@
+ <img src="./dreamer4-fig2.png" width="400px"></img>
+
+ ## Dreamer 4
+
+ Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v1) for his [Dreamer](https://danijar.com/project/dreamer4/) line of work
+
+ [Discord channel](https://discord.gg/ab4BEk3W) for collaborating with other researchers interested in this work
+
+ ## Appreciation
+
+ - [@dirkmcpherson](https://github.com/dirkmcpherson) for fixes to typo errors and unpassed arguments!
+
+ ## Install
+
+ ```bash
+ $ pip install dreamer4
+ ```
+
+ ## Usage
+
+ ```python
+ import torch
+ from dreamer4 import VideoTokenizer, DynamicsWorldModel
+
+ # video tokenizer, learned through MAE + lpips
+
+ tokenizer = VideoTokenizer(
+     dim = 512,
+     dim_latent = 32,
+     patch_size = 32,
+     image_height = 256,
+     image_width = 256
+ )
+
+ video = torch.randn(2, 3, 10, 256, 256)
+
+ # learn the tokenizer
+
+ loss = tokenizer(video)
+ loss.backward()
+
+ # dynamics world model
+
+ world_model = DynamicsWorldModel(
+     dim = 512,
+     dim_latent = 32,
+     video_tokenizer = tokenizer,
+     num_discrete_actions = 4,
+     num_residual_streams = 1
+ )
+
+ # state, action, rewards
+
+ video = torch.randn(2, 3, 10, 256, 256)
+ discrete_actions = torch.randint(0, 4, (2, 10, 1))
+ rewards = torch.randn(2, 10)
+
+ # learn dynamics / behavior cloned model
+
+ loss = world_model(
+     video = video,
+     rewards = rewards,
+     discrete_actions = discrete_actions
+ )
+
+ loss.backward()
+
+ # do the above with much data
+
+ # then generate dreams
+
+ dreams = world_model.generate(
+     10,
+     batch_size = 2,
+     return_decoded_video = True,
+     return_for_policy_optimization = True
+ )
+
+ # learn from the dreams
+
+ actor_loss, critic_loss = world_model.learn_from_experience(dreams)
+
+ (actor_loss + critic_loss).backward()
+
+ # learn from environment
+
+ from dreamer4.mocks import MockEnv
+
+ mock_env = MockEnv((256, 256), vectorized = True, num_envs = 4)
+
+ experience = world_model.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = True)
+
+ actor_loss, critic_loss = world_model.learn_from_experience(experience)
+
+ (actor_loss + critic_loss).backward()
+ ```
+
+ ## Citation
+
+ ```bibtex
+ @misc{hafner2025trainingagentsinsidescalable,
+ title = {Training Agents Inside of Scalable World Models},
+ author = {Danijar Hafner and Wilson Yan and Timothy Lillicrap},
+ year = {2025},
+ eprint = {2509.24527},
+ archivePrefix = {arXiv},
+ primaryClass = {cs.AI},
+ url = {https://arxiv.org/abs/2509.24527},
+ }
+ ```
+
+ *the conquest of nature is to be achieved through number and measure - angels to Descartes in a dream*
@@ -14,7 +14,7 @@ from torch.nested import nested_tensor
 from torch.distributions import Normal, kl
 from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity
 from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, full, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
- from torch.utils._pytree import tree_flatten, tree_unflatten
+ from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten

 import torchvision
 from torchvision.models import VGG16_Weights
@@ -91,6 +91,14 @@ class Experience:
 agent_index: int = 0
 is_from_world_model: bool = True

+ def cpu(self):
+ return self.to(torch.device('cpu'))
+
+ def to(self, device):
+ experience_dict = asdict(self)
+ experience_dict = tree_map(lambda t: t.to(device) if is_tensor(t) else t, experience_dict)
+ return Experience(**experience_dict)
+
 def combine_experiences(
 exps: list[Experiences]
 ) -> Experience:
@@ -1179,10 +1187,11 @@ def special_token_mask(q, k, seq_len, num_tokens, special_attend_only_itself = F

 def block_mask_special_tokens_right(
 seq_len,
- num_tokens
+ num_tokens,
+ special_attend_only_itself = False
 ):
 def inner(b, h, q, k):
- return special_token_mask(q, k, seq_len, num_tokens)
+ return special_token_mask(q, k, seq_len, num_tokens, special_attend_only_itself)
 return inner

 def compose_mask(mask1, mask2):
@@ -1331,6 +1340,12 @@ class Attention(Module):
 q = self.q_heads_rmsnorm(q)
 k = self.k_heads_rmsnorm(k)

+ # rotary
+
+ if exists(rotary_pos_emb):
+ q = apply_rotations(rotary_pos_emb, q)
+ k = apply_rotations(rotary_pos_emb, k)
+
 # caching

 if exists(kv_cache):
@@ -1338,12 +1353,6 @@ class Attention(Module):
 k = cat((ck, k), dim = -2)
 v = cat((cv, v), dim = -2)

- # rotary
-
- if exists(rotary_pos_emb):
- q = apply_rotations(rotary_pos_emb, q)
- k = apply_rotations(rotary_pos_emb, k)
-
 # attention

 attend_fn = default(attend_fn, naive_attend)
@@ -1493,7 +1502,8 @@ class AxialSpaceTimeTransformer(Module):

 # attend functions for space and time

- use_flex = exists(flex_attention) and tokens.is_cuda
+ has_kv_cache = exists(kv_cache)
+ use_flex = exists(flex_attention) and tokens.is_cuda and not has_kv_cache # KV cache shape breaks flex attention TODO: Fix

 attend_kwargs = dict(use_flex = use_flex, softclamp_value = self.attn_softclamp_value, special_attend_only_itself = self.special_attend_only_itself, device = device)

@@ -1505,14 +1515,12 @@ class AxialSpaceTimeTransformer(Module):

 time_attn_kv_caches = []

- has_kv_cache = exists(kv_cache)
-

 if has_kv_cache:
 past_tokens, tokens = tokens[:, :-1], tokens[:, -1:]

 rotary_seq_len = 1
- rotary_pos_offset = past_tokens.shape[-2]
+ rotary_pos_offset = past_tokens.shape[1]
 else:
 rotary_seq_len = time
 rotary_pos_offset = 0
@@ -1687,6 +1695,7 @@ class VideoTokenizer(Module):
 time_block_every = time_block_every,
 num_special_spatial_tokens = num_latent_tokens,
 num_residual_streams = num_residual_streams,
+ special_attend_only_itself = True,
 final_norm = True
 )

@@ -1847,7 +1856,7 @@ class VideoTokenizer(Module):

 losses = (recon_loss, lpips_loss)

- return total_loss, TokenizerLosses(losses)
+ return total_loss, TokenizerLosses(*losses)

 # dynamics model, axial space-time transformer

@@ -1900,7 +1909,9 @@ class DynamicsWorldModel(Module):
 gae_lambda = 0.95,
 ppo_eps_clip = 0.2,
 pmpo_pos_to_neg_weight = 0.5, # pos and neg equal weight
- pmpo_kl_div_loss_weight = 1.,
+ pmpo_reverse_kl = True,
+ pmpo_kl_div_loss_weight = .3,
+ normalize_advantages = None,
 value_clip = 0.4,
 policy_entropy_weight = .01,
 gae_use_accelerated = False
@@ -2102,12 +2113,13 @@ class DynamicsWorldModel(Module):

 self.ppo_eps_clip = ppo_eps_clip
 self.value_clip = value_clip
- self.policy_entropy_weight = value_clip
+ self.policy_entropy_weight = policy_entropy_weight

 # pmpo related

 self.pmpo_pos_to_neg_weight = pmpo_pos_to_neg_weight
 self.pmpo_kl_div_loss_weight = pmpo_kl_div_loss_weight
+ self.pmpo_reverse_kl = pmpo_reverse_kl

 # rewards related

@@ -2124,7 +2136,7 @@ class DynamicsWorldModel(Module):
 self.flow_loss_normalizer = LossNormalizer(1)
 self.reward_loss_normalizer = LossNormalizer(multi_token_pred_len)
 self.discrete_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if num_discrete_actions > 0 else None
- self.continuous_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if num_discrete_actions > 0 else None
+ self.continuous_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if num_continuous_actions > 0 else None

 self.latent_flow_loss_weight = latent_flow_loss_weight

@@ -2355,6 +2367,9 @@ class DynamicsWorldModel(Module):
 elif len(env_step_out) == 4:
 next_frame, reward, terminated, truncated = env_step_out

+ elif len(env_step_out) == 5:
+ next_frame, reward, terminated, truncated, info = env_step_out
+
 # update episode lens

 episode_lens = torch.where(is_terminated, episode_lens, episode_lens + 1)
@@ -2423,8 +2438,12 @@ class DynamicsWorldModel(Module):
 value_optim: Optimizer | None = None,
 only_learn_policy_value_heads = True, # in the paper, they do not finetune the entire dynamics model, they just learn the heads
 use_pmpo = True,
+ normalize_advantages = None,
 eps = 1e-6
 ):
+ assert isinstance(experience, Experience)
+
+ experience = experience.to(self.device)

 latents = experience.latents
 actions = experience.actions
@@ -2437,7 +2456,7 @@ class DynamicsWorldModel(Module):
 step_size = experience.step_size
 agent_index = experience.agent_index

- assert all([*map(exists, (old_log_probs, actions, old_values, rewards, step_size))]), 'the generations need to contain the log probs, values, and rewards for policy optimization'
+ assert all([*map(exists, (old_log_probs, actions, old_values, rewards, step_size))]), 'the generations need to contain the log probs, values, and rewards for policy optimization - world_model.generate(..., return_log_probs_and_values = True)'

 batch, time = latents.shape[0], latents.shape[1]

@@ -2451,8 +2470,8 @@ class DynamicsWorldModel(Module):
 if exists(experience.lens):
 mask_for_gae = lens_to_mask(experience.lens, time)

- rewards = rewards.masked_fill(mask_for_gae, 0.)
- old_values = old_values.masked_fill(mask_for_gae, 0.)
+ rewards = rewards.masked_fill(~mask_for_gae, 0.)
+ old_values = old_values.masked_fill(~mask_for_gae, 0.)

 # calculate returns

@@ -2487,7 +2506,7 @@ class DynamicsWorldModel(Module):

 # mean, var - todo - handle distributed

- returns_mean, returns_var = returns.mean(), returns.var()
+ returns_mean, returns_var = returns_for_stats.mean(), returns_for_stats.var()

 # ema

@@ -2505,16 +2524,19 @@ class DynamicsWorldModel(Module):
 else:
 advantage = returns - old_values

- # apparently they just use the sign of the advantage
+ # if using pmpo, do not normalize advantages, but can be overridden
+
+ normalize_advantages = default(normalize_advantages, not use_pmpo)
+
+ if normalize_advantages:
+ advantage = F.layer_norm(advantage, advantage.shape, eps = eps)
+
 # https://arxiv.org/abs/2410.04166v1

 if use_pmpo:
 pos_advantage_mask = advantage >= 0.
 neg_advantage_mask = ~pos_advantage_mask

- else:
- advantage = F.layer_norm(advantage, advantage.shape, eps = eps)
-
 # replay for the action logits and values
 # but only do so if fine tuning the entire world model for RL

@@ -2578,11 +2600,18 @@ class DynamicsWorldModel(Module):
 # take care of kl

 if self.pmpo_kl_div_loss_weight > 0.:
+
 new_unembedded_actions = self.action_embedder.unembed(policy_embed, pred_head_index = 0)

+ kl_div_inputs, kl_div_targets = new_unembedded_actions, old_action_unembeds
+
 # mentioned that the "reverse direction for the prior KL" was used
+ # make optional, as observed instability in toy task
+
+ if self.pmpo_reverse_kl:
+ kl_div_inputs, kl_div_targets = kl_div_targets, kl_div_inputs

- discrete_kl_div, continuous_kl_div = self.action_embedder.kl_div(old_action_unembeds, new_unembedded_actions)
+ discrete_kl_div, continuous_kl_div = self.action_embedder.kl_div(kl_div_inputs, kl_div_targets)

 # accumulate discrete and continuous kl div

@@ -2680,12 +2709,22 @@ class DynamicsWorldModel(Module):
 return_rewards_per_frame = False,
 return_agent_actions = False,
 return_log_probs_and_values = False,
+ return_for_policy_optimization = False,
 return_time_kv_cache = False,
 store_agent_embed = True,
 store_old_action_unembeds = True

 ): # (b t n d) | (b c t h w)

+ # handy flag for returning generations for rl
+
+ if return_for_policy_optimization:
+ return_agent_actions |= True
+ return_log_probs_and_values |= True
+ return_rewards_per_frame |= True
+
+ # more variables
+
 has_proprio = self.has_proprio
 was_training = self.training
 self.eval()
@@ -2755,6 +2794,19 @@ class DynamicsWorldModel(Module):

 curr_time_steps = latents.shape[1]

+ # determine whether to take an extra step if
+ # (1) using time kv cache
+ # (2) decoding anything off agent embedding (rewards, actions, etc)
+
+ take_extra_step = (
+ use_time_kv_cache or
+ return_rewards_per_frame or
+ store_agent_embed or
+ return_agent_actions
+ )
+
+ # prepare noised latent / proprio inputs
+
 noised_latent = randn((batch_size, 1, self.num_video_views, *latent_shape), device = self.device)

 noised_proprio = None
@@ -2762,7 +2814,10 @@ class DynamicsWorldModel(Module):
 if has_proprio:
 noised_proprio = randn((batch_size, 1, self.dim_proprio), device = self.device)

- for step in range(num_steps):
+ # denoising steps
+
+ for step in range(num_steps + int(take_extra_step)):
+
 is_last_step = (step + 1) == num_steps
 signal_levels = full((batch_size, 1), step * step_size, dtype = torch.long, device = self.device)

@@ -2805,6 +2860,11 @@ class DynamicsWorldModel(Module):
 if use_time_kv_cache and is_last_step:
 time_kv_cache = next_time_kv_cache

+ # early break if taking an extra step for agent embedding off cleaned latents for decoding
+
+ if take_extra_step and is_last_step:
+ break
+
 # maybe proprio

 if has_proprio:
@@ -3007,7 +3067,7 @@ class DynamicsWorldModel(Module):
 latent_is_noised = False,
 return_all_losses = False,
 return_intermediates = False,
- add_autoregressive_action_loss = False,
+ add_autoregressive_action_loss = True,
 update_loss_ema = None,
 latent_has_view_dim = False
 ):
@@ -3039,8 +3099,8 @@ class DynamicsWorldModel(Module):
 if latents.ndim == 4:
 latents = rearrange(latents, 'b t v d -> b t v 1 d') # 1 latent edge case

- assert latents.shape[-2:] == self.latent_shape
- assert latents.shape[2] == self.num_video_views
+ assert latents.shape[-2:] == self.latent_shape, f'latents must have shape {self.latent_shape}, got {latents.shape[-2:]}'
+ assert latents.shape[2] == self.num_video_views, f'latents must have {self.num_video_views} views, got {latents.shape[2]}'

 # variables

@@ -3464,7 +3524,7 @@ class DynamicsWorldModel(Module):

 reward_losses = F.cross_entropy(reward_pred, reward_targets, reduction = 'none')

- reward_losses = reward_losses.masked_fill(reward_loss_mask, 0.)
+ reward_losses = reward_losses.masked_fill(~reward_loss_mask, 0.)

 if is_var_len:
 reward_loss = reward_losses[loss_mask_without_last].mean(dim = 0)
@@ -3508,7 +3568,7 @@ class DynamicsWorldModel(Module):
 discrete_mask = rearrange(discrete_mask, 'b t mtp -> mtp b t')

 if exists(continuous_actions):
- continuous_action_targets, continuous_mask = create_multi_token_prediction_targets(discrete_actions, self.multi_token_pred_len)
+ continuous_action_targets, continuous_mask = create_multi_token_prediction_targets(continuous_actions, self.multi_token_pred_len)
 continuous_action_targets = rearrange(continuous_action_targets, 'b t mtp ... -> mtp b t ...')
 continuous_mask = rearrange(continuous_mask, 'b t mtp -> mtp b t')

@@ -528,7 +528,7 @@ class SimTrainer(Module):

 total_experience += num_experience

- experiences.append(experience)
+ experiences.append(experience.cpu())

 combined_experiences = combine_experiences(experiences)

@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
- version = "0.0.99"
+ version = "0.1.5"
 description = "Dreamer 4"
 authors = [
 { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -680,6 +680,12 @@ def test_online_rl(

 combined_experience = combine_experiences([one_experience, another_experience])

+ # quick test moving the experience to different devices
+
+ if torch.cuda.is_available():
+ combined_experience = combined_experience.to(torch.device('cuda'))
+ combined_experience = combined_experience.to(world_model_and_policy.device)
+
 if store_agent_embed:
 assert exists(combined_experience.agent_embed)

dreamer4-0.0.99/README.md DELETED
@@ -1,21 +0,0 @@
- <img src="./dreamer4-fig2.png" width="400px"></img>
-
- ## Dreamer 4 (wip)
-
- Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v1) for his [Dreamer](https://danijar.com/project/dreamer4/) line of work
-
- [Temporary Discord](https://discord.gg/MkACrrkrYR)
-
- ## Citation
-
- ```bibtex
- @misc{hafner2025trainingagentsinsidescalable,
- title = {Training Agents Inside of Scalable World Models},
- author = {Danijar Hafner and Wilson Yan and Timothy Lillicrap},
- year = {2025},
- eprint = {2509.24527},
- archivePrefix = {arXiv},
- primaryClass = {cs.AI},
- url = {https://arxiv.org/abs/2509.24527},
- }
- ```