dreamer4-0.0.102-py3-none-any.whl → dreamer4-0.1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of dreamer4 might be problematic.

dreamer4/dreamer4.py CHANGED
@@ -1331,6 +1331,12 @@ class Attention(Module):
         q = self.q_heads_rmsnorm(q)
         k = self.k_heads_rmsnorm(k)
 
+        # rotary
+
+        if exists(rotary_pos_emb):
+            q = apply_rotations(rotary_pos_emb, q)
+            k = apply_rotations(rotary_pos_emb, k)
+
         # caching
 
         if exists(kv_cache):
@@ -1338,12 +1344,6 @@ class Attention(Module):
             k = cat((ck, k), dim = -2)
             v = cat((cv, v), dim = -2)
 
-        # rotary
-
-        if exists(rotary_pos_emb):
-            q = apply_rotations(rotary_pos_emb, q)
-            k = apply_rotations(rotary_pos_emb, k)
-
         # attention
 
         attend_fn = default(attend_fn, naive_attend)
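
The two hunks above move the rotary application so it happens before the kv cache is extended: only the newest query and key receive the current position's rotation, so each key carries the rotation for its own position exactly once instead of having the current step's rotation re-applied across the concatenated sequence. A minimal sketch of that ordering, outside the library (`rotate` below is a stand-in for `apply_rotations`, not the real function):

```python
# minimal sketch (not the library's code) of why the rotation goes before the
# cache concat: only the newest key gets the current position's rotation,
# cached keys were already rotated at their own positions when first computed
import torch

def rotate(t, pos):
    # stand-in for apply_rotations - any position-dependent transform illustrates the point
    angle = torch.tensor(float(pos))
    return t * torch.cos(angle) + t.flip(-1) * torch.sin(angle)

cached_k = torch.stack([rotate(torch.randn(8), p) for p in range(3)])  # keys for positions 0, 1, 2

new_k = rotate(torch.randn(1, 8), 3)        # rotate only the new key, at position 3
k = torch.cat((cached_k, new_k), dim = 0)   # every key rotated exactly once, at its own position
```
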
@@ -1507,12 +1507,11 @@ class AxialSpaceTimeTransformer(Module):
 
         has_kv_cache = exists(kv_cache)
 
-
         if has_kv_cache:
             past_tokens, tokens = tokens[:, :-1], tokens[:, -1:]
 
             rotary_seq_len = 1
-            rotary_pos_offset = past_tokens.shape[-2]
+            rotary_pos_offset = past_tokens.shape[1]
         else:
             rotary_seq_len = time
             rotary_pos_offset = 0
@@ -1687,6 +1686,7 @@ class VideoTokenizer(Module):
             time_block_every = time_block_every,
             num_special_spatial_tokens = num_latent_tokens,
             num_residual_streams = num_residual_streams,
+            special_attend_only_itself = True,
             final_norm = True
         )
 
@@ -2429,6 +2429,7 @@ class DynamicsWorldModel(Module):
         normalize_advantages = None,
         eps = 1e-6
     ):
+        assert isinstance(experience, Experience)
 
         latents = experience.latents
         actions = experience.actions
@@ -2441,7 +2442,7 @@ class DynamicsWorldModel(Module):
         step_size = experience.step_size
         agent_index = experience.agent_index
 
-        assert all([*map(exists, (old_log_probs, actions, old_values, rewards, step_size))]), 'the generations need to contain the log probs, values, and rewards for policy optimization'
+        assert all([*map(exists, (old_log_probs, actions, old_values, rewards, step_size))]), 'the generations need to contain the log probs, values, and rewards for policy optimization - world_model.generate(..., return_log_probs_and_values = True)'
 
         batch, time = latents.shape[0], latents.shape[1]
 
@@ -2455,8 +2456,8 @@ class DynamicsWorldModel(Module):
         if exists(experience.lens):
             mask_for_gae = lens_to_mask(experience.lens, time)
 
-            rewards = rewards.masked_fill(mask_for_gae, 0.)
-            old_values = old_values.masked_fill(mask_for_gae, 0.)
+            rewards = rewards.masked_fill(~mask_for_gae, 0.)
+            old_values = old_values.masked_fill(~mask_for_gae, 0.)
 
         # calculate returns
 
@@ -2491,7 +2492,7 @@ class DynamicsWorldModel(Module):
 
         # mean, var - todo - handle distributed
 
-        returns_mean, returns_var = returns.mean(), returns.var()
+        returns_mean, returns_var = returns_for_stats.mean(), returns_for_stats.var()
 
         # ema
 
@@ -2694,12 +2695,22 @@ class DynamicsWorldModel(Module):
         return_rewards_per_frame = False,
         return_agent_actions = False,
         return_log_probs_and_values = False,
+        return_for_policy_optimization = False,
         return_time_kv_cache = False,
         store_agent_embed = True,
         store_old_action_unembeds = True
 
     ): # (b t n d) | (b c t h w)
 
+        # handy flag for returning generations for rl
+
+        if return_for_policy_optimization:
+            return_agent_actions |= True
+            return_log_probs_and_values |= True
+            return_rewards_per_frame |= True
+
+        # more variables
+
         has_proprio = self.has_proprio
         was_training = self.training
         self.eval()
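
Going by the new README further down, `return_for_policy_optimization` is a convenience flag that switches on the three return flags `learn_from_experience` needs. A usage sketch, assuming `dynamics` is a `DynamicsWorldModel` constructed as in that README example:

```python
# usage sketch, assuming `dynamics` is a DynamicsWorldModel as in the README below
dreams = dynamics.generate(
    10,
    batch_size = 2,
    return_for_policy_optimization = True  # implies agent actions, log probs / values, rewards per frame
)

actor_loss, critic_loss = dynamics.learn_from_experience(dreams)
(actor_loss + critic_loss).backward()
```
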
@@ -2769,6 +2780,19 @@ class DynamicsWorldModel(Module):
 
         curr_time_steps = latents.shape[1]
 
+        # determine whether to take an extra step if
+        # (1) using time kv cache
+        # (2) decoding anything off agent embedding (rewards, actions, etc)
+
+        take_extra_step = (
+            use_time_kv_cache or
+            return_rewards_per_frame or
+            store_agent_embed or
+            return_agent_actions
+        )
+
+        # prepare noised latent / proprio inputs
+
         noised_latent = randn((batch_size, 1, self.num_video_views, *latent_shape), device = self.device)
 
         noised_proprio = None
@@ -2776,7 +2800,10 @@ class DynamicsWorldModel(Module):
         if has_proprio:
             noised_proprio = randn((batch_size, 1, self.dim_proprio), device = self.device)
 
-        for step in range(num_steps):
+        # denoising steps
+
+        for step in range(num_steps + int(take_extra_step)):
+
             is_last_step = (step + 1) == num_steps
 
             signal_levels = full((batch_size, 1), step * step_size, dtype = torch.long, device = self.device)
@@ -2819,6 +2846,11 @@ class DynamicsWorldModel(Module):
             if use_time_kv_cache and is_last_step:
                 time_kv_cache = next_time_kv_cache
 
+            # early break if taking an extra step for agent embedding off cleaned latents for decoding
+
+            if take_extra_step and is_last_step:
+                break
+
             # maybe proprio
 
             if has_proprio:
@@ -3021,7 +3053,7 @@ class DynamicsWorldModel(Module):
         latent_is_noised = False,
         return_all_losses = False,
         return_intermediates = False,
-        add_autoregressive_action_loss = False,
+        add_autoregressive_action_loss = True,
         update_loss_ema = None,
         latent_has_view_dim = False
     ):
@@ -3478,7 +3510,7 @@ class DynamicsWorldModel(Module):
 
         reward_losses = F.cross_entropy(reward_pred, reward_targets, reduction = 'none')
 
-        reward_losses = reward_losses.masked_fill(reward_loss_mask, 0.)
+        reward_losses = reward_losses.masked_fill(~reward_loss_mask, 0.)
 
         if is_var_len:
             reward_loss = reward_losses[loss_mask_without_last].mean(dim = 0)
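
The `~mask_for_gae` and `~reward_loss_mask` flips above read as the same polarity fix: these masks are True at valid positions, so they have to be inverted before `masked_fill` so that only the padded tail is zeroed. A small self-contained sketch of that convention, with a stand-in `lens_to_mask`:

```python
# sketch of the mask polarity: True marks valid positions, so padding is
# zeroed by filling where the inverted mask is True
import torch

def lens_to_mask(lens, max_len):
    # stand-in helper: True where a position falls within the sequence length
    return torch.arange(max_len) < lens[:, None]

lens = torch.tensor([3, 5])
rewards = torch.ones(2, 5)

mask = lens_to_mask(lens, 5)              # True at valid positions
rewards = rewards.masked_fill(~mask, 0.)  # zero only the padded tail
# rewards[0] -> [1., 1., 1., 0., 0.]
```
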

dreamer4-{0.0.102 → 0.1.1}.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dreamer4
-Version: 0.0.102
+Version: 0.1.1
 Summary: Dreamer 4
 Project-URL: Homepage, https://pypi.org/project/dreamer4/
 Project-URL: Repository, https://github.com/lucidrains/dreamer4
@@ -53,11 +53,75 @@ Description-Content-Type: text/markdown
 
 <img src="./dreamer4-fig2.png" width="400px"></img>
 
-## Dreamer 4 (wip)
+## Dreamer 4
 
 Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v1) for his [Dreamer](https://danijar.com/project/dreamer4/) line of work
 
-[Temporary Discord](https://discord.gg/MkACrrkrYR)
+## Install
+
+```bash
+$ pip install dreamer4
+```
+
+## Usage
+
+```python
+import torch
+from dreamer4 import VideoTokenizer, DynamicsWorldModel
+
+# video tokenizer, learned through MAE + lpips
+
+tokenizer = VideoTokenizer(
+    dim = 512,
+    dim_latent = 32,
+    patch_size = 32,
+    image_height = 256,
+    image_width = 256
+)
+
+# dynamics world model
+
+dynamics = DynamicsWorldModel(
+    dim = 512,
+    dim_latent = 32,
+    video_tokenizer = tokenizer,
+    num_discrete_actions = 4,
+    num_residual_streams = 1
+)
+
+# state, action, rewards
+
+video = torch.randn(2, 3, 10, 256, 256)
+discrete_actions = torch.randint(0, 4, (2, 10, 1))
+rewards = torch.randn(2, 10)
+
+# learn dynamics / behavior cloned model
+
+loss = dynamics(
+    video = video,
+    rewards = rewards,
+    discrete_actions = discrete_actions
+)
+
+loss.backward()
+
+# do the above with much data
+
+# then generate dreams
+
+dreams = dynamics.generate(
+    10,
+    batch_size = 2,
+    return_decoded_video = True,
+    return_for_policy_optimization = True
+)
+
+# learn from the dreams
+
+actor_loss, critic_loss = dynamics.learn_from_experience(dreams)
+
+(actor_loss + critic_loss).backward()
+```
 
 ## Citation
 
@@ -72,3 +136,5 @@ Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v
     url = {https://arxiv.org/abs/2509.24527},
 }
 ```
+
+*the conquest of nature is to be achieved through number and measure - angels to Descartes in a dream*

dreamer4-0.1.1.dist-info/RECORD ADDED
@@ -0,0 +1,8 @@
+dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
+dreamer4/dreamer4.py,sha256=_gWp08k7tf2VCUv7uqkXKZQugnqJqXPb1-o7_34SA9c,120365
+dreamer4/mocks.py,sha256=TfqOB_Gq6N_GggBYwa6ZAJQx38ntlYbXZe23Ne4jshw,2502
+dreamer4/trainers.py,sha256=JsnJwQJcbI_75KBTNddG6b7QVkO6LD1N_HQiVe-VnCM,15087
+dreamer4-0.1.1.dist-info/METADATA,sha256=2zkBv1BHvGpb6onAnEFsKnPK2KD-0vH8K1nFDBVlpyU,4247
+dreamer4-0.1.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+dreamer4-0.1.1.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+dreamer4-0.1.1.dist-info/RECORD,,

dreamer4-0.0.102.dist-info/RECORD DELETED
@@ -1,8 +0,0 @@
-dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
-dreamer4/dreamer4.py,sha256=3qeVN3qdvx7iPxA0OBXw_yy5Re6rX6FIKITH9bp6RBs,119202
-dreamer4/mocks.py,sha256=TfqOB_Gq6N_GggBYwa6ZAJQx38ntlYbXZe23Ne4jshw,2502
-dreamer4/trainers.py,sha256=JsnJwQJcbI_75KBTNddG6b7QVkO6LD1N_HQiVe-VnCM,15087
-dreamer4-0.0.102.dist-info/METADATA,sha256=xxVL1sFimb0azSD5sDOEzugY7rBT6oDek4YdiIS8m18,3066
-dreamer4-0.0.102.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-dreamer4-0.0.102.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-dreamer4-0.0.102.dist-info/RECORD,,