dreamer4 0.1.0__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dreamer4/dreamer4.py CHANGED
@@ -1179,10 +1179,11 @@ def special_token_mask(q, k, seq_len, num_tokens, special_attend_only_itself = F
 
 def block_mask_special_tokens_right(
     seq_len,
-    num_tokens
+    num_tokens,
+    special_attend_only_itself = False
 ):
     def inner(b, h, q, k):
-        return special_token_mask(q, k, seq_len, num_tokens)
+        return special_token_mask(q, k, seq_len, num_tokens, special_attend_only_itself)
     return inner
 
 def compose_mask(mask1, mask2):
@@ -1331,6 +1332,12 @@ class Attention(Module):
         q = self.q_heads_rmsnorm(q)
         k = self.k_heads_rmsnorm(k)
 
+        # rotary
+
+        if exists(rotary_pos_emb):
+            q = apply_rotations(rotary_pos_emb, q)
+            k = apply_rotations(rotary_pos_emb, k)
+
         # caching
 
         if exists(kv_cache):
@@ -1338,12 +1345,6 @@ class Attention(Module):
             k = cat((ck, k), dim = -2)
             v = cat((cv, v), dim = -2)
 
-        # rotary
-
-        if exists(rotary_pos_emb):
-            q = apply_rotations(rotary_pos_emb, q)
-            k = apply_rotations(rotary_pos_emb, k)
-
         # attention
 
         attend_fn = default(attend_fn, naive_attend)
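
The two hunks above move the rotary embedding step so that q and k are rotated before any cached keys and values are concatenated in. With incremental decoding the rotary embedding covers only the new position (see the offset change further down), so applying it after the cache concat would re-rotate keys that were already rotated when they were first computed. Below is a minimal sketch of this rotate-then-cache pattern; the helpers `rope_frequencies` and `apply_rope` are illustrative stand-ins, not dreamer4's `apply_rotations`.

```python
# illustrative sketch of rotate-then-cache decoding with RoPE; helper names are hypothetical

import torch

def rope_frequencies(seq_len, dim, offset = 0):
    # rotation angles for positions [offset, offset + seq_len)
    inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
    pos = torch.arange(offset, offset + seq_len).float()
    return torch.einsum('i,j->ij', pos, inv_freq)            # (seq_len, dim / 2)

def apply_rope(freqs, x):
    x1, x2 = x.chunk(2, dim = -1)
    cos, sin = freqs.cos(), freqs.sin()
    return torch.cat((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim = -1)

dim, cache_len = 8, 4
k_cache = torch.randn(1, cache_len, dim)                     # keys already rotated at positions 0..3 when first computed

k_new = torch.randn(1, 1, dim)                               # single new token
freqs = rope_frequencies(1, dim, offset = cache_len)         # rotation for position 4 only

k_new = apply_rope(freqs, k_new)                             # rotate first ...
k = torch.cat((k_cache, k_new), dim = -2)                    # ... then concatenate with the cache
```
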
@@ -1493,7 +1494,8 @@ class AxialSpaceTimeTransformer(Module):
 
         # attend functions for space and time
 
-        use_flex = exists(flex_attention) and tokens.is_cuda
+        has_kv_cache = exists(kv_cache)
+        use_flex = exists(flex_attention) and tokens.is_cuda and not has_kv_cache # KV cache shape breaks flex attention TODO: Fix
 
         attend_kwargs = dict(use_flex = use_flex, softclamp_value = self.attn_softclamp_value, special_attend_only_itself = self.special_attend_only_itself, device = device)
 
@@ -1505,14 +1507,12 @@ class AxialSpaceTimeTransformer(Module):
 
         time_attn_kv_caches = []
 
-        has_kv_cache = exists(kv_cache)
-
 
         if has_kv_cache:
             past_tokens, tokens = tokens[:, :-1], tokens[:, -1:]
 
             rotary_seq_len = 1
-            rotary_pos_offset = past_tokens.shape[-2]
+            rotary_pos_offset = past_tokens.shape[1]
         else:
             rotary_seq_len = time
             rotary_pos_offset = 0
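
On the offset fix above: when a KV cache is present, only the latest frame of tokens is run through the transformer, and the rotary offset has to count how many timesteps precede it. Assuming the token tensor is laid out as (batch, time, space, dim), which the `tokens[:, :-1]` slicing suggests, `shape[1]` is the time axis while `shape[-2]` would have picked up the per-frame spatial axis. A small illustration under that assumed layout:

```python
import torch

# assumed layout: (batch, time, space, dim), mirroring the tokens[:, :-1] slicing above
tokens = torch.randn(2, 7, 64, 512)

past_tokens, tokens = tokens[:, :-1], tokens[:, -1:]

print(past_tokens.shape[1])    # 6  -> past timesteps, the rotary position offset we want
print(past_tokens.shape[-2])   # 64 -> spatial tokens per frame, the previous (incorrect) offset
```
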
@@ -1687,6 +1687,7 @@ class VideoTokenizer(Module):
             time_block_every = time_block_every,
             num_special_spatial_tokens = num_latent_tokens,
             num_residual_streams = num_residual_streams,
+            special_attend_only_itself = True,
             final_norm = True
         )
 
@@ -1847,7 +1848,7 @@ class VideoTokenizer(Module):
 
         losses = (recon_loss, lpips_loss)
 
-        return total_loss, TokenizerLosses(losses)
+        return total_loss, TokenizerLosses(*losses)
 
 # dynamics model, axial space-time transformer
 
@@ -2104,7 +2105,7 @@ class DynamicsWorldModel(Module):
 
         self.ppo_eps_clip = ppo_eps_clip
         self.value_clip = value_clip
-        self.policy_entropy_weight = value_clip
+        self.policy_entropy_weight = policy_entropy_weight
 
         # pmpo related
 
@@ -2127,7 +2128,7 @@ class DynamicsWorldModel(Module):
         self.flow_loss_normalizer = LossNormalizer(1)
         self.reward_loss_normalizer = LossNormalizer(multi_token_pred_len)
         self.discrete_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if num_discrete_actions > 0 else None
-        self.continuous_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if num_discrete_actions > 0 else None
+        self.continuous_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if num_continuous_actions > 0 else None
 
         self.latent_flow_loss_weight = latent_flow_loss_weight
 
@@ -2358,6 +2359,9 @@ class DynamicsWorldModel(Module):
             elif len(env_step_out) == 4:
                 next_frame, reward, terminated, truncated = env_step_out
 
+            elif len(env_step_out) == 5:
+                next_frame, reward, terminated, truncated, info = env_step_out
+
             # update episode lens
 
             episode_lens = torch.where(is_terminated, episode_lens, episode_lens + 1)
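
The new `len(env_step_out) == 5` branch accepts environments that follow the Gymnasium-style step signature `(obs, reward, terminated, truncated, info)`, alongside the shorter returns handled above it. A hedged sketch of the same normalization written as a standalone helper (illustrative, not part of dreamer4):

```python
def normalize_step(env_step_out):
    # accept the common step() return arities and normalize to one 5-tuple
    if len(env_step_out) == 3:
        next_frame, reward, terminated = env_step_out
        truncated, info = False, {}
    elif len(env_step_out) == 4:
        next_frame, reward, terminated, truncated = env_step_out
        info = {}
    elif len(env_step_out) == 5:
        # gymnasium-style: (obs, reward, terminated, truncated, info)
        next_frame, reward, terminated, truncated, info = env_step_out
    else:
        raise ValueError(f'unexpected step output of length {len(env_step_out)}')

    return next_frame, reward, terminated, truncated, info
```
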
@@ -2456,8 +2460,8 @@ class DynamicsWorldModel(Module):
         if exists(experience.lens):
             mask_for_gae = lens_to_mask(experience.lens, time)
 
-            rewards = rewards.masked_fill(mask_for_gae, 0.)
-            old_values = old_values.masked_fill(mask_for_gae, 0.)
+            rewards = rewards.masked_fill(~mask_for_gae, 0.)
+            old_values = old_values.masked_fill(~mask_for_gae, 0.)
 
         # calculate returns
 
@@ -2492,7 +2496,7 @@ class DynamicsWorldModel(Module):
 
         # mean, var - todo - handle distributed
 
-        returns_mean, returns_var = returns.mean(), returns.var()
+        returns_mean, returns_var = returns_for_stats.mean(), returns_for_stats.var()
 
         # ema
 
@@ -3085,8 +3089,8 @@ class DynamicsWorldModel(Module):
         if latents.ndim == 4:
             latents = rearrange(latents, 'b t v d -> b t v 1 d') # 1 latent edge case
 
-        assert latents.shape[-2:] == self.latent_shape
-        assert latents.shape[2] == self.num_video_views
+        assert latents.shape[-2:] == self.latent_shape, f'latents must have shape {self.latent_shape}, got {latents.shape[-2:]}'
+        assert latents.shape[2] == self.num_video_views, f'latents must have {self.num_video_views} views, got {latents.shape[2]}'
 
         # variables
 
@@ -3510,7 +3514,7 @@ class DynamicsWorldModel(Module):
 
         reward_losses = F.cross_entropy(reward_pred, reward_targets, reduction = 'none')
 
-        reward_losses = reward_losses.masked_fill(reward_loss_mask, 0.)
+        reward_losses = reward_losses.masked_fill(~reward_loss_mask, 0.)
 
         if is_var_len:
             reward_loss = reward_losses[loss_mask_without_last].mean(dim = 0)
@@ -3554,7 +3558,7 @@ class DynamicsWorldModel(Module):
             discrete_mask = rearrange(discrete_mask, 'b t mtp -> mtp b t')
 
         if exists(continuous_actions):
-            continuous_action_targets, continuous_mask = create_multi_token_prediction_targets(discrete_actions, self.multi_token_pred_len)
+            continuous_action_targets, continuous_mask = create_multi_token_prediction_targets(continuous_actions, self.multi_token_pred_len)
             continuous_action_targets = rearrange(continuous_action_targets, 'b t mtp ... -> mtp b t ...')
             continuous_mask = rearrange(continuous_mask, 'b t mtp -> mtp b t')
 
dreamer4-0.1.0.dist-info/METADATA → dreamer4-0.1.4.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dreamer4
-Version: 0.1.0
+Version: 0.1.4
 Summary: Dreamer 4
 Project-URL: Homepage, https://pypi.org/project/dreamer4/
 Project-URL: Repository, https://github.com/lucidrains/dreamer4
@@ -57,10 +57,16 @@ Description-Content-Type: text/markdown
 
 Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v1) for his [Dreamer](https://danijar.com/project/dreamer4/) line of work
 
+[Discord channel](https://discord.gg/ab4BEk3W) for collaborating with other researchers interested in this work
+
+## Appreciation
+
+- [@dirkmcpherson](https://github.com/dirkmcpherson) for fixes to typo errors and unpassed arguments!
+
 ## Install
 
 ```bash
-$ pip install dreamer4-pytorch
+$ pip install dreamer4
 ```
 
 ## Usage
@@ -79,9 +85,16 @@ tokenizer = VideoTokenizer(
     image_width = 256
 )
 
+video = torch.randn(2, 3, 10, 256, 256)
+
+# learn the tokenizer
+
+loss = tokenizer(video)
+loss.backward() # ler
+
 # dynamics world model
 
-dynamics = DynamicsWorldModel(
+world_model = DynamicsWorldModel(
     dim = 512,
     dim_latent = 32,
     video_tokenizer = tokenizer,
@@ -97,7 +110,7 @@ rewards = torch.randn(2, 10)
 
 # learn dynamics / behavior cloned model
 
-loss = dynamics(
+loss = world_model(
     video = video,
     rewards = rewards,
     discrete_actions = discrete_actions
@@ -109,7 +122,7 @@ loss.backward()
 
 # then generate dreams
 
-dreams = dynamics.generate(
+dreams = world_model.generate(
     10,
     batch_size = 2,
     return_decoded_video = True,
@@ -118,7 +131,19 @@ dreams = dynamics.generate(
 
 # learn from the dreams
 
-actor_loss, critic_loss = dynamics.learn_from_experience(dreams)
+actor_loss, critic_loss = world_model.learn_from_experience(dreams)
+
+(actor_loss + critic_loss).backward()
+
+# learn from environment
+
+from dreamer4.mocks import MockEnv
+
+mock_env = MockEnv((256, 256), vectorized = True, num_envs = 4)
+
+experience = world_model.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = True)
+
+actor_loss, critic_loss = world_model.learn_from_experience(experience)
 
 (actor_loss + critic_loss).backward()
 ```
@@ -137,4 +162,4 @@ actor_loss, critic_loss = dynamics.learn_from_experience(dreams)
 }
 ```
 
-*the conquest of nature is to be achieved through number and measure* - angels to Descartes, in a dream, the story goes.
+*the conquest of nature is to be achieved through number and measure - angels to Descartes in a dream*
dreamer4-0.1.0.dist-info/RECORD → dreamer4-0.1.4.dist-info/RECORD RENAMED
@@ -0,0 +1,8 @@
+dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
+dreamer4/dreamer4.py,sha256=ghestMgz7B1oEqBRR0XkkdWe0kkh7bshhzmi6-n-XIs,120790
+dreamer4/mocks.py,sha256=TfqOB_Gq6N_GggBYwa6ZAJQx38ntlYbXZe23Ne4jshw,2502
+dreamer4/trainers.py,sha256=JsnJwQJcbI_75KBTNddG6b7QVkO6LD1N_HQiVe-VnCM,15087
+dreamer4-0.1.4.dist-info/METADATA,sha256=GkzuqKtNJJCSh5FycWJOr49253_w926biJkSz9ic4TQ,4941
+dreamer4-0.1.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+dreamer4-0.1.4.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+dreamer4-0.1.4.dist-info/RECORD,,
@@ -1,8 +0,0 @@
-dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
-dreamer4/dreamer4.py,sha256=YB724hMjBYDNhApo2x_52oXIeH5GGQo8Q2pB2lkCq_s,120297
-dreamer4/mocks.py,sha256=TfqOB_Gq6N_GggBYwa6ZAJQx38ntlYbXZe23Ne4jshw,2502
-dreamer4/trainers.py,sha256=JsnJwQJcbI_75KBTNddG6b7QVkO6LD1N_HQiVe-VnCM,15087
-dreamer4-0.1.0.dist-info/METADATA,sha256=kDq66Il_WDNKR2NP9wrVu3fMUIVW-pWySq3CP2ANZ2s,4273
-dreamer4-0.1.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-dreamer4-0.1.0.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-dreamer4-0.1.0.dist-info/RECORD,,