dreamer4 0.1.2__tar.gz → 0.1.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of dreamer4 might be problematic.

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dreamer4
-Version: 0.1.2
+Version: 0.1.4
 Summary: Dreamer 4
 Project-URL: Homepage, https://pypi.org/project/dreamer4/
 Project-URL: Repository, https://github.com/lucidrains/dreamer4
@@ -57,6 +57,12 @@ Description-Content-Type: text/markdown
 
 Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v1) for his [Dreamer](https://danijar.com/project/dreamer4/) line of work
 
+[Discord channel](https://discord.gg/ab4BEk3W) for collaborating with other researchers interested in this work
+
+## Appreciation
+
+- [@dirkmcpherson](https://github.com/dirkmcpherson) for fixes to typo errors and unpassed arguments!
+
 ## Install
 
 ```bash
@@ -79,9 +85,16 @@ tokenizer = VideoTokenizer(
     image_width = 256
 )
 
+video = torch.randn(2, 3, 10, 256, 256)
+
+# learn the tokenizer
+
+loss = tokenizer(video)
+loss.backward() # ler
+
 # dynamics world model
 
-dynamics = DynamicsWorldModel(
+world_model = DynamicsWorldModel(
     dim = 512,
     dim_latent = 32,
     video_tokenizer = tokenizer,
@@ -97,7 +110,7 @@ rewards = torch.randn(2, 10)
 
 # learn dynamics / behavior cloned model
 
-loss = dynamics(
+loss = world_model(
     video = video,
     rewards = rewards,
     discrete_actions = discrete_actions
@@ -109,7 +122,7 @@ loss.backward()
 
 # then generate dreams
 
-dreams = dynamics.generate(
+dreams = world_model.generate(
     10,
     batch_size = 2,
     return_decoded_video = True,
@@ -118,7 +131,19 @@ dreams = dynamics.generate(
 
 # learn from the dreams
 
-actor_loss, critic_loss = dynamics.learn_from_experience(dreams)
+actor_loss, critic_loss = world_model.learn_from_experience(dreams)
+
+(actor_loss + critic_loss).backward()
+
+# learn from environment
+
+from dreamer4.mocks import MockEnv
+
+mock_env = MockEnv((256, 256), vectorized = True, num_envs = 4)
+
+experience = world_model.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = True)
+
+actor_loss, critic_loss = world_model.learn_from_experience(experience)
 
 (actor_loss + critic_loss).backward()
 ```
@@ -4,6 +4,12 @@
 
 Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v1) for his [Dreamer](https://danijar.com/project/dreamer4/) line of work
 
+[Discord channel](https://discord.gg/ab4BEk3W) for collaborating with other researchers interested in this work
+
+## Appreciation
+
+- [@dirkmcpherson](https://github.com/dirkmcpherson) for fixes to typo errors and unpassed arguments!
+
 ## Install
 
 ```bash
@@ -26,9 +32,16 @@ tokenizer = VideoTokenizer(
     image_width = 256
 )
 
+video = torch.randn(2, 3, 10, 256, 256)
+
+# learn the tokenizer
+
+loss = tokenizer(video)
+loss.backward() # ler
+
 # dynamics world model
 
-dynamics = DynamicsWorldModel(
+world_model = DynamicsWorldModel(
     dim = 512,
     dim_latent = 32,
     video_tokenizer = tokenizer,
@@ -44,7 +57,7 @@ rewards = torch.randn(2, 10)
 
 # learn dynamics / behavior cloned model
 
-loss = dynamics(
+loss = world_model(
     video = video,
     rewards = rewards,
     discrete_actions = discrete_actions
@@ -56,7 +69,7 @@ loss.backward()
 
 # then generate dreams
 
-dreams = dynamics.generate(
+dreams = world_model.generate(
     10,
     batch_size = 2,
     return_decoded_video = True,
@@ -65,7 +78,19 @@ dreams = dynamics.generate(
 
 # learn from the dreams
 
-actor_loss, critic_loss = dynamics.learn_from_experience(dreams)
+actor_loss, critic_loss = world_model.learn_from_experience(dreams)
+
+(actor_loss + critic_loss).backward()
+
+# learn from environment
+
+from dreamer4.mocks import MockEnv
+
+mock_env = MockEnv((256, 256), vectorized = True, num_envs = 4)
+
+experience = world_model.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = True)
+
+actor_loss, critic_loss = world_model.learn_from_experience(experience)
 
 (actor_loss + critic_loss).backward()
 ```
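The usage changes above appear twice because the same README text ships both as the package's long description and as the standalone README file. As a quick orientation for the new tokenizer-training lines, here is a minimal self-contained sketch of the dummy video tensor's layout; reading the dimensions as (batch, channels, frames, height, width) is an assumption inferred from image_width = 256 and rewards = torch.randn(2, 10), not something the diff states.

```python
import torch

# the README's dummy input, presumably (batch, channels, frames, height, width)
video = torch.randn(2, 3, 10, 256, 256)

batch, channels, frames, height, width = video.shape

assert (height, width) == (256, 256)  # matches image_width = 256 passed to VideoTokenizer
assert frames == 10                   # matches rewards = torch.randn(2, 10), one reward per frame
assert batch == 2                     # matches the leading dimension of rewards
```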
@@ -1179,10 +1179,11 @@ def special_token_mask(q, k, seq_len, num_tokens, special_attend_only_itself = False):
 
 def block_mask_special_tokens_right(
     seq_len,
-    num_tokens
+    num_tokens,
+    special_attend_only_itself = False
 ):
     def inner(b, h, q, k):
-        return special_token_mask(q, k, seq_len, num_tokens)
+        return special_token_mask(q, k, seq_len, num_tokens, special_attend_only_itself)
     return inner
 
 def compose_mask(mask1, mask2):
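This hunk fixes one of the unpassed arguments credited in the Appreciation note: the mask factory now accepts special_attend_only_itself and forwards it into the closure it returns. Below is a self-contained sketch of that pattern with hypothetical stand-in names (make_block_mask and build_mask are simplifications, not the library's functions).

```python
def build_mask(q, k, seq_len, num_tokens, special_attend_only_itself):
    # stand-in for special_token_mask: just echo the flag so forwarding is observable
    return special_attend_only_itself

def make_block_mask(seq_len, num_tokens, special_attend_only_itself = False):
    def inner(b, h, q, k):
        # before the fix the flag stopped at the factory and the default was always used;
        # after the fix it is threaded through to the inner mask function
        return build_mask(q, k, seq_len, num_tokens, special_attend_only_itself)
    return inner

mask_fn = make_block_mask(seq_len = 16, num_tokens = 4, special_attend_only_itself = True)
assert mask_fn(0, 0, 0, 0) is True  # the flag now reaches the mask construction
```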
@@ -1493,7 +1494,8 @@ class AxialSpaceTimeTransformer(Module):
 
         # attend functions for space and time
 
-        use_flex = exists(flex_attention) and tokens.is_cuda
+        has_kv_cache = exists(kv_cache)
+        use_flex = exists(flex_attention) and tokens.is_cuda and not has_kv_cache # KV cache shape breaks flex attention TODO: Fix
 
         attend_kwargs = dict(use_flex = use_flex, softclamp_value = self.attn_softclamp_value, special_attend_only_itself = self.special_attend_only_itself, device = device)
 
@@ -1505,7 +1507,6 @@ class AxialSpaceTimeTransformer(Module):
 
         time_attn_kv_caches = []
 
-        has_kv_cache = exists(kv_cache)
 
         if has_kv_cache:
             past_tokens, tokens = tokens[:, :-1], tokens[:, -1:]
@@ -1847,7 +1848,7 @@ class VideoTokenizer(Module):
 
         losses = (recon_loss, lpips_loss)
 
-        return total_loss, TokenizerLosses(losses)
+        return total_loss, TokenizerLosses(*losses)
 
 # dynamics model, axial space-time transformer
 
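The fix above unpacks the loss tuple before building TokenizerLosses. Assuming TokenizerLosses is a two-field named tuple matching losses = (recon_loss, lpips_loss) (the field names below are hypothetical), the unstarred call fails because only one of the two fields is supplied, whereas unpacking fills both.

```python
from collections import namedtuple

# hypothetical stand-in with two fields, mirroring losses = (recon_loss, lpips_loss)
TokenizerLosses = namedtuple('TokenizerLosses', ['recon', 'lpips'])

losses = (0.5, 0.1)

# TokenizerLosses(losses) raises TypeError: two fields are expected but only one
# positional argument (the whole tuple) is passed

breakdown = TokenizerLosses(*losses)  # unpacking fills both fields as intended
assert breakdown.recon == 0.5 and breakdown.lpips == 0.1
```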
@@ -2104,7 +2105,7 @@ class DynamicsWorldModel(Module):
 
         self.ppo_eps_clip = ppo_eps_clip
         self.value_clip = value_clip
-        self.policy_entropy_weight = value_clip
+        self.policy_entropy_weight = policy_entropy_weight
 
         # pmpo related
 
@@ -2127,7 +2128,7 @@ class DynamicsWorldModel(Module):
         self.flow_loss_normalizer = LossNormalizer(1)
         self.reward_loss_normalizer = LossNormalizer(multi_token_pred_len)
         self.discrete_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if num_discrete_actions > 0 else None
-        self.continuous_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if num_discrete_actions > 0 else None
+        self.continuous_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if num_continuous_actions > 0 else None
 
         self.latent_flow_loss_weight = latent_flow_loss_weight
 
@@ -2358,6 +2359,9 @@ class DynamicsWorldModel(Module):
        elif len(env_step_out) == 4:
            next_frame, reward, terminated, truncated = env_step_out
 
+       elif len(env_step_out) == 5:
+           next_frame, reward, terminated, truncated, info = env_step_out
+
        # update episode lens
 
        episode_lens = torch.where(is_terminated, episode_lens, episode_lens + 1)
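The added branch accepts environments whose step() returns five values, matching the Gymnasium-style (observation, reward, terminated, truncated, info) convention. A minimal self-contained sketch of such an environment follows; the class is illustrative only and not part of dreamer4.

```python
import torch

class FiveTupleEnv:
    # illustrative environment whose step() returns the 5-tuple the new branch unpacks

    def reset(self):
        return torch.randn(3, 256, 256)

    def step(self, action):
        next_frame = torch.randn(3, 256, 256)
        reward = 0.
        terminated = False
        truncated = False
        info = {}  # the extra trailing element handled by the len(env_step_out) == 5 case
        return next_frame, reward, terminated, truncated, info

env_step_out = FiveTupleEnv().step(action = None)
assert len(env_step_out) == 5
```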
@@ -3085,8 +3089,8 @@ class DynamicsWorldModel(Module):
        if latents.ndim == 4:
            latents = rearrange(latents, 'b t v d -> b t v 1 d') # 1 latent edge case
 
-        assert latents.shape[-2:] == self.latent_shape
-        assert latents.shape[2] == self.num_video_views
+        assert latents.shape[-2:] == self.latent_shape, f'latents must have shape {self.latent_shape}, got {latents.shape[-2:]}'
+        assert latents.shape[2] == self.num_video_views, f'latents must have {self.num_video_views} views, got {latents.shape[2]}'
 
        # variables
 
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.1.2"
+version = "0.1.4"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
7 files without changes