dreamer4 0.1.2__tar.gz → 0.1.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: dreamer4
- Version: 0.1.2
+ Version: 0.1.5
  Summary: Dreamer 4
  Project-URL: Homepage, https://pypi.org/project/dreamer4/
  Project-URL: Repository, https://github.com/lucidrains/dreamer4
@@ -57,6 +57,12 @@ Description-Content-Type: text/markdown

  Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v1) for his [Dreamer](https://danijar.com/project/dreamer4/) line of work

+ [Discord channel](https://discord.gg/ab4BEk3W) for collaborating with other researchers interested in this work
+
+ ## Appreciation
+
+ - [@dirkmcpherson](https://github.com/dirkmcpherson) for fixes to typo errors and unpassed arguments!
+
  ## Install

  ```bash
@@ -79,9 +85,16 @@ tokenizer = VideoTokenizer(
  image_width = 256
  )

+ video = torch.randn(2, 3, 10, 256, 256)
+
+ # learn the tokenizer
+
+ loss = tokenizer(video)
+ loss.backward() # ler
+
  # dynamics world model

- dynamics = DynamicsWorldModel(
+ world_model = DynamicsWorldModel(
  dim = 512,
  dim_latent = 32,
  video_tokenizer = tokenizer,
@@ -97,7 +110,7 @@ rewards = torch.randn(2, 10)

  # learn dynamics / behavior cloned model

- loss = dynamics(
+ loss = world_model(
  video = video,
  rewards = rewards,
  discrete_actions = discrete_actions
@@ -109,7 +122,7 @@ loss.backward()

  # then generate dreams

- dreams = dynamics.generate(
+ dreams = world_model.generate(
  10,
  batch_size = 2,
  return_decoded_video = True,
@@ -118,7 +131,19 @@ dreams = dynamics.generate(

  # learn from the dreams

- actor_loss, critic_loss = dynamics.learn_from_experience(dreams)
+ actor_loss, critic_loss = world_model.learn_from_experience(dreams)
+
+ (actor_loss + critic_loss).backward()
+
+ # learn from environment
+
+ from dreamer4.mocks import MockEnv
+
+ mock_env = MockEnv((256, 256), vectorized = True, num_envs = 4)
+
+ experience = world_model.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = True)
+
+ actor_loss, critic_loss = world_model.learn_from_experience(experience)

  (actor_loss + critic_loss).backward()
  ```
@@ -4,6 +4,12 @@

  Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v1) for his [Dreamer](https://danijar.com/project/dreamer4/) line of work

+ [Discord channel](https://discord.gg/ab4BEk3W) for collaborating with other researchers interested in this work
+
+ ## Appreciation
+
+ - [@dirkmcpherson](https://github.com/dirkmcpherson) for fixes to typo errors and unpassed arguments!
+
  ## Install

  ```bash
@@ -26,9 +32,16 @@ tokenizer = VideoTokenizer(
  image_width = 256
  )

+ video = torch.randn(2, 3, 10, 256, 256)
+
+ # learn the tokenizer
+
+ loss = tokenizer(video)
+ loss.backward() # ler
+
  # dynamics world model

- dynamics = DynamicsWorldModel(
+ world_model = DynamicsWorldModel(
  dim = 512,
  dim_latent = 32,
  video_tokenizer = tokenizer,
@@ -44,7 +57,7 @@ rewards = torch.randn(2, 10)

  # learn dynamics / behavior cloned model

- loss = dynamics(
+ loss = world_model(
  video = video,
  rewards = rewards,
  discrete_actions = discrete_actions
@@ -56,7 +69,7 @@ loss.backward()

  # then generate dreams

- dreams = dynamics.generate(
+ dreams = world_model.generate(
  10,
  batch_size = 2,
  return_decoded_video = True,
@@ -65,7 +78,19 @@ dreams = dynamics.generate(

  # learn from the dreams

- actor_loss, critic_loss = dynamics.learn_from_experience(dreams)
+ actor_loss, critic_loss = world_model.learn_from_experience(dreams)
+
+ (actor_loss + critic_loss).backward()
+
+ # learn from environment
+
+ from dreamer4.mocks import MockEnv
+
+ mock_env = MockEnv((256, 256), vectorized = True, num_envs = 4)
+
+ experience = world_model.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = True)
+
+ actor_loss, critic_loss = world_model.learn_from_experience(experience)

  (actor_loss + critic_loss).backward()
  ```
@@ -14,7 +14,7 @@ from torch.nested import nested_tensor
  from torch.distributions import Normal, kl
  from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity
  from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, full, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
- from torch.utils._pytree import tree_flatten, tree_unflatten
+ from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten

  import torchvision
  from torchvision.models import VGG16_Weights
@@ -91,6 +91,14 @@ class Experience:
  agent_index: int = 0
  is_from_world_model: bool = True

+ def cpu(self):
+ return self.to(torch.device('cpu'))
+
+ def to(self, device):
+ experience_dict = asdict(self)
+ experience_dict = tree_map(lambda t: t.to(device) if is_tensor(t) else t, experience_dict)
+ return Experience(**experience_dict)
+
  def combine_experiences(
  exps: list[Experiences]
  ) -> Experience:
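The `cpu` / `to` helpers added to `Experience` above move every tensor field of the dataclass at once by running `torch.utils._pytree.tree_map` over `asdict(self)`, which is why the import hunk now pulls in `tree_map`. A minimal sketch of the same pattern, with a hypothetical `Sample` dataclass standing in for `Experience`:

```python
from dataclasses import dataclass, asdict

import torch
from torch import is_tensor
from torch.utils._pytree import tree_map

@dataclass
class Sample:
    # hypothetical stand-in for Experience, mixing tensor and non-tensor fields
    latents: torch.Tensor
    rewards: torch.Tensor
    agent_index: int = 0

    def to(self, device):
        # tree_map visits every leaf of the asdict(...) pytree and moves only the tensors
        moved = tree_map(lambda t: t.to(device) if is_tensor(t) else t, asdict(self))
        return Sample(**moved)

    def cpu(self):
        return self.to(torch.device('cpu'))

sample = Sample(latents = torch.randn(2, 8), rewards = torch.randn(2))
sample = sample.cpu()  # tensors moved, agent_index passes through unchanged
```

This is the mechanism that lets `SimTrainer` offload collected rollouts with `experience.cpu()` and lets `learn_from_experience` pull them back onto the model's device, as the later hunks show.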
@@ -1179,10 +1187,11 @@ def special_token_mask(q, k, seq_len, num_tokens, special_attend_only_itself = F

  def block_mask_special_tokens_right(
  seq_len,
- num_tokens
+ num_tokens,
+ special_attend_only_itself = False
  ):
  def inner(b, h, q, k):
- return special_token_mask(q, k, seq_len, num_tokens)
+ return special_token_mask(q, k, seq_len, num_tokens, special_attend_only_itself)
  return inner

  def compose_mask(mask1, mask2):
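The `inner(b, h, q, k)` closure returned by `block_mask_special_tokens_right` is a mask predicate of the kind PyTorch's flex attention consumes; this change simply threads the new `special_attend_only_itself` flag through to it. A rough sketch of how such a predicate is typically compiled into a block mask (PyTorch >= 2.5); the plain causal predicate, shapes, and sizes below are illustrative only, not dreamer4's actual mask logic:

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

seq_len = 128

def mask_mod(b, h, q_idx, kv_idx):
    # stand-in predicate; dreamer4's inner(b, h, q, k) plays this role
    return q_idx >= kv_idx  # plain causal masking, purely for illustration

if torch.cuda.is_available():
    block_mask = create_block_mask(mask_mod, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len, device = 'cuda')

    q = torch.randn(1, 4, seq_len, 64, device = 'cuda')
    k = torch.randn(1, 4, seq_len, 64, device = 'cuda')
    v = torch.randn(1, 4, seq_len, 64, device = 'cuda')

    out = flex_attention(q, k, v, block_mask = block_mask)
```

The next hunk then turns flex attention off whenever a KV cache is passed in, since cached key/value shapes no longer line up with a mask built for the full sequence (per the TODO comment in the diff).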
@@ -1493,7 +1502,8 @@ class AxialSpaceTimeTransformer(Module):

  # attend functions for space and time

- use_flex = exists(flex_attention) and tokens.is_cuda
+ has_kv_cache = exists(kv_cache)
+ use_flex = exists(flex_attention) and tokens.is_cuda and not has_kv_cache # KV cache shape breaks flex attention TODO: Fix

  attend_kwargs = dict(use_flex = use_flex, softclamp_value = self.attn_softclamp_value, special_attend_only_itself = self.special_attend_only_itself, device = device)

@@ -1505,7 +1515,6 @@ class AxialSpaceTimeTransformer(Module):

  time_attn_kv_caches = []

- has_kv_cache = exists(kv_cache)

  if has_kv_cache:
  past_tokens, tokens = tokens[:, :-1], tokens[:, -1:]
@@ -1847,7 +1856,7 @@ class VideoTokenizer(Module):

  losses = (recon_loss, lpips_loss)

- return total_loss, TokenizerLosses(losses)
+ return total_loss, TokenizerLosses(*losses)

  # dynamics model, axial space-time transformer

@@ -2104,7 +2113,7 @@ class DynamicsWorldModel(Module):

  self.ppo_eps_clip = ppo_eps_clip
  self.value_clip = value_clip
- self.policy_entropy_weight = value_clip
+ self.policy_entropy_weight = policy_entropy_weight

  # pmpo related

@@ -2127,7 +2136,7 @@ class DynamicsWorldModel(Module):
  self.flow_loss_normalizer = LossNormalizer(1)
  self.reward_loss_normalizer = LossNormalizer(multi_token_pred_len)
  self.discrete_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if num_discrete_actions > 0 else None
- self.continuous_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if num_discrete_actions > 0 else None
+ self.continuous_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if num_continuous_actions > 0 else None

  self.latent_flow_loss_weight = latent_flow_loss_weight

@@ -2358,6 +2367,9 @@ class DynamicsWorldModel(Module):
  elif len(env_step_out) == 4:
  next_frame, reward, terminated, truncated = env_step_out

+ elif len(env_step_out) == 5:
+ next_frame, reward, terminated, truncated, info = env_step_out
+
  # update episode lens

  episode_lens = torch.where(is_terminated, episode_lens, episode_lens + 1)
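The new branch lets `interact_with_env` consume environments whose `step` returns the gymnasium-style five-tuple `(obs, reward, terminated, truncated, info)` alongside the shorter return shapes already handled. A toy stub of that contract; the frame shape, dtype, and `reset` signature are illustrative assumptions, not requirements taken from the package:

```python
import numpy as np

class FiveTupleEnv:
    # toy environment whose step() matches the new len(env_step_out) == 5 branch

    def reset(self):
        # gymnasium-style (observation, info) pair, assumed here for completeness
        return np.zeros((256, 256, 3), dtype = np.float32), {}

    def step(self, action):
        next_frame = np.zeros((256, 256, 3), dtype = np.float32)
        reward = 0.0
        terminated = False
        truncated = False
        return next_frame, reward, terminated, truncated, {}
```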
@@ -2431,6 +2443,8 @@ class DynamicsWorldModel(Module):
  ):
  assert isinstance(experience, Experience)

+ experience = experience.to(self.device)
+
  latents = experience.latents
  actions = experience.actions
  old_log_probs = experience.log_probs
@@ -3085,8 +3099,8 @@ class DynamicsWorldModel(Module):
  if latents.ndim == 4:
  latents = rearrange(latents, 'b t v d -> b t v 1 d') # 1 latent edge case

- assert latents.shape[-2:] == self.latent_shape
- assert latents.shape[2] == self.num_video_views
+ assert latents.shape[-2:] == self.latent_shape, f'latents must have shape {self.latent_shape}, got {latents.shape[-2:]}'
+ assert latents.shape[2] == self.num_video_views, f'latents must have {self.num_video_views} views, got {latents.shape[2]}'

  # variables

@@ -528,7 +528,7 @@ class SimTrainer(Module):

  total_experience += num_experience

- experiences.append(experience)
+ experiences.append(experience.cpu())

  combined_experiences = combine_experiences(experiences)

@@ -1,6 +1,6 @@
  [project]
  name = "dreamer4"
- version = "0.1.2"
+ version = "0.1.5"
  description = "Dreamer 4"
  authors = [
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -680,6 +680,12 @@ def test_online_rl(

  combined_experience = combine_experiences([one_experience, another_experience])

+ # quick test moving the experience to different devices
+
+ if torch.cuda.is_available():
+ combined_experience = combined_experience.to(torch.device('cuda'))
+ combined_experience = combined_experience.to(world_model_and_policy.device)
+
  if store_agent_embed:
  assert exists(combined_experience.agent_embed)

Five additional files have no changes.