dreamer4-0.1.0-py3-none-any.whl → dreamer4-0.1.4-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dreamer4/dreamer4.py +26 -22
- {dreamer4-0.1.0.dist-info → dreamer4-0.1.4.dist-info}/METADATA +32 -7
- dreamer4-0.1.4.dist-info/RECORD +8 -0
- dreamer4-0.1.0.dist-info/RECORD +0 -8
- {dreamer4-0.1.0.dist-info → dreamer4-0.1.4.dist-info}/WHEEL +0 -0
- {dreamer4-0.1.0.dist-info → dreamer4-0.1.4.dist-info}/licenses/LICENSE +0 -0
dreamer4/dreamer4.py
CHANGED
@@ -1179,10 +1179,11 @@ def special_token_mask(q, k, seq_len, num_tokens, special_attend_only_itself = F

 def block_mask_special_tokens_right(
     seq_len,
-    num_tokens
+    num_tokens,
+    special_attend_only_itself = False
 ):
     def inner(b, h, q, k):
-        return special_token_mask(q, k, seq_len, num_tokens)
+        return special_token_mask(q, k, seq_len, num_tokens, special_attend_only_itself)
     return inner

 def compose_mask(mask1, mask2):
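The new `special_attend_only_itself` flag is now threaded through to `special_token_mask` by the returned closure. That closure has the `(b, h, q_idx, kv_idx)` shape of a flex-attention `mask_mod`, which is presumably how it is consumed (the later `use_flex` change confirms flex attention is in play). A minimal sketch of that pattern, with hypothetical mask logic standing in for `special_token_mask`:

```python
# Sketch only - assumes the closure is consumed as a flex-attention mask_mod.
# block_mask_suffix_tokens is a hypothetical stand-in, not the library's function.
# Requires PyTorch >= 2.5 and a CUDA device.

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def block_mask_suffix_tokens(seq_len, num_tokens, special_attend_only_itself = False):
    def inner(b, h, q_idx, kv_idx):
        causal = q_idx >= kv_idx
        q_is_special = q_idx >= (seq_len - num_tokens)

        if special_attend_only_itself:
            # special (suffix) tokens attend only to themselves
            return torch.where(q_is_special, q_idx == kv_idx, causal)

        return causal
    return inner

seq_len = 128
mask_mod = block_mask_suffix_tokens(seq_len, num_tokens = 4, special_attend_only_itself = True)
block_mask = create_block_mask(mask_mod, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len)

q = k = v = torch.randn(1, 8, seq_len, 64, device = 'cuda')
out = flex_attention(q, k, v, block_mask = block_mask)
```

Baking the extra argument into the closure keeps the `mask_mod` signature fixed while still letting callers toggle behaviour such as `special_attend_only_itself`.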
@@ -1331,6 +1332,12 @@ class Attention(Module):
         q = self.q_heads_rmsnorm(q)
         k = self.k_heads_rmsnorm(k)

+        # rotary
+
+        if exists(rotary_pos_emb):
+            q = apply_rotations(rotary_pos_emb, q)
+            k = apply_rotations(rotary_pos_emb, k)
+
         # caching

         if exists(kv_cache):
@@ -1338,12 +1345,6 @@ class Attention(Module):
             k = cat((ck, k), dim = -2)
             v = cat((cv, v), dim = -2)

-        # rotary
-
-        if exists(rotary_pos_emb):
-            q = apply_rotations(rotary_pos_emb, q)
-            k = apply_rotations(rotary_pos_emb, k)
-
         # attention

         attend_fn = default(attend_fn, naive_attend)
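The two hunks above move the rotary application from after the KV-cache concatenation to before it. The ordering matters: cached keys were already rotated at their own positions on earlier steps, so only the incoming single-step `q`/`k` should be rotated before being appended. A rough sketch of the corrected ordering, with a toy rotation helper standing in for `apply_rotations`:

```python
# Toy illustration of the ordering, not the library's implementation.

import torch

def apply_rotations_sketch(freqs, t):
    # rotate channel pairs by the per-position angles in freqs
    t1, t2 = t.chunk(2, dim = -1)
    cos, sin = freqs.cos(), freqs.sin()
    return torch.cat((t1 * cos - t2 * sin, t1 * sin + t2 * cos), dim = -1)

def attend_step(q, k, v, kv_cache, freqs):
    # 1. rotate only the *new* q / k at their (offset) positions
    q = apply_rotations_sketch(freqs, q)
    k = apply_rotations_sketch(freqs, k)

    # 2. only then extend the cache - cached keys were rotated on earlier steps
    if kv_cache is not None:
        ck, cv = kv_cache
        k = torch.cat((ck, k), dim = -2)
        v = torch.cat((cv, v), dim = -2)

    return q, k, v
```

Keeping the old order would rotate the already-rotated cached keys a second time whenever a KV cache is present.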
@@ -1493,7 +1494,8 @@ class AxialSpaceTimeTransformer(Module):

         # attend functions for space and time

-
+        has_kv_cache = exists(kv_cache)
+        use_flex = exists(flex_attention) and tokens.is_cuda and not has_kv_cache # KV cache shape breaks flex attention TODO: Fix

         attend_kwargs = dict(use_flex = use_flex, softclamp_value = self.attn_softclamp_value, special_attend_only_itself = self.special_attend_only_itself, device = device)

@@ -1505,14 +1507,12 @@ class AxialSpaceTimeTransformer(Module):

         time_attn_kv_caches = []

-        has_kv_cache = exists(kv_cache)
-

         if has_kv_cache:
             past_tokens, tokens = tokens[:, :-1], tokens[:, -1:]

             rotary_seq_len = 1
-            rotary_pos_offset = past_tokens.shape[
+            rotary_pos_offset = past_tokens.shape[1]
         else:
             rotary_seq_len = time
             rotary_pos_offset = 0
@@ -1687,6 +1687,7 @@ class VideoTokenizer(Module):
             time_block_every = time_block_every,
             num_special_spatial_tokens = num_latent_tokens,
             num_residual_streams = num_residual_streams,
+            special_attend_only_itself = True,
             final_norm = True
         )

@@ -1847,7 +1848,7 @@ class VideoTokenizer(Module):

         losses = (recon_loss, lpips_loss)

-        return total_loss, TokenizerLosses(losses)
+        return total_loss, TokenizerLosses(*losses)

 # dynamics model, axial space-time transformer

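`TokenizerLosses` is presumably a namedtuple with one field per loss term, so the tuple has to be star-unpacked rather than passed as a single first argument. A tiny illustration, with assumed field names:

```python
# Assumed two-field namedtuple; the real field names may differ.

from collections import namedtuple

TokenizerLosses = namedtuple('TokenizerLosses', ['recon', 'lpips'])

losses = (0.5, 0.1)

TokenizerLosses(*losses)    # TokenizerLosses(recon=0.5, lpips=0.1)
# TokenizerLosses(losses)   # TypeError: missing 1 required positional argument: 'lpips'
```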
@@ -2104,7 +2105,7 @@ class DynamicsWorldModel(Module):

         self.ppo_eps_clip = ppo_eps_clip
         self.value_clip = value_clip
-        self.policy_entropy_weight =
+        self.policy_entropy_weight = policy_entropy_weight

         # pmpo related

@@ -2127,7 +2128,7 @@ class DynamicsWorldModel(Module):
         self.flow_loss_normalizer = LossNormalizer(1)
         self.reward_loss_normalizer = LossNormalizer(multi_token_pred_len)
         self.discrete_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if num_discrete_actions > 0 else None
-        self.continuous_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if
+        self.continuous_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if num_continuous_actions > 0 else None

         self.latent_flow_loss_weight = latent_flow_loss_weight

@@ -2358,6 +2359,9 @@ class DynamicsWorldModel(Module):
         elif len(env_step_out) == 4:
             next_frame, reward, terminated, truncated = env_step_out

+        elif len(env_step_out) == 5:
+            next_frame, reward, terminated, truncated, info = env_step_out
+
         # update episode lens

         episode_lens = torch.where(is_terminated, episode_lens, episode_lens + 1)
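The added branch covers environments whose `step` returns the five-tuple `(obs, reward, terminated, truncated, info)` alongside the existing four-tuple case. A sketch of the dispatch, as a hypothetical helper mirroring the branches above:

```python
# Hypothetical helper mirroring the branches above, not the library's code.

def unpack_env_step(env_step_out):
    info = None

    if len(env_step_out) == 4:
        next_frame, reward, terminated, truncated = env_step_out
    elif len(env_step_out) == 5:
        # gymnasium-style environments also return an info dict
        next_frame, reward, terminated, truncated, info = env_step_out
    else:
        raise ValueError(f'unexpected env.step output of length {len(env_step_out)}')

    return next_frame, reward, terminated, truncated, info
```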
@@ -2456,8 +2460,8 @@ class DynamicsWorldModel(Module):
         if exists(experience.lens):
             mask_for_gae = lens_to_mask(experience.lens, time)

-            rewards = rewards.masked_fill(mask_for_gae, 0.)
-            old_values = old_values.masked_fill(mask_for_gae, 0.)
+            rewards = rewards.masked_fill(~mask_for_gae, 0.)
+            old_values = old_values.masked_fill(~mask_for_gae, 0.)

         # calculate returns

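The sign of the mask was flipped: assuming `lens_to_mask` returns True for timesteps inside each episode, the positions to zero out are the padded ones, i.e. `~mask_for_gae`. A small illustration with an assumed `lens_to_mask` equivalent:

```python
# Assumes lens_to_mask returns True for timesteps within each episode.

import torch

def lens_to_mask_sketch(lens, max_len):
    return torch.arange(max_len)[None, :] < lens[:, None]

rewards = torch.ones(2, 5)
lens = torch.tensor([3, 5])

mask_for_gae = lens_to_mask_sketch(lens, 5)     # [[T, T, T, F, F], [T, T, T, T, T]]

rewards.masked_fill(~mask_for_gae, 0.)          # zeros the padded steps past episode end
# rewards.masked_fill(mask_for_gae, 0.)         # would instead zero the valid steps
```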
@@ -2492,7 +2496,7 @@ class DynamicsWorldModel(Module):

         # mean, var - todo - handle distributed

-        returns_mean, returns_var =
+        returns_mean, returns_var = returns_for_stats.mean(), returns_for_stats.var()

         # ema

@@ -3085,8 +3089,8 @@ class DynamicsWorldModel(Module):
         if latents.ndim == 4:
             latents = rearrange(latents, 'b t v d -> b t v 1 d') # 1 latent edge case

-        assert latents.shape[-2:] == self.latent_shape
-        assert latents.shape[2] == self.num_video_views
+        assert latents.shape[-2:] == self.latent_shape, f'latents must have shape {self.latent_shape}, got {latents.shape[-2:]}'
+        assert latents.shape[2] == self.num_video_views, f'latents must have {self.num_video_views} views, got {latents.shape[2]}'

         # variables

@@ -3510,7 +3514,7 @@ class DynamicsWorldModel(Module):

         reward_losses = F.cross_entropy(reward_pred, reward_targets, reduction = 'none')

-        reward_losses = reward_losses.masked_fill(reward_loss_mask, 0.)
+        reward_losses = reward_losses.masked_fill(~reward_loss_mask, 0.)

         if is_var_len:
             reward_loss = reward_losses[loss_mask_without_last].mean(dim = 0)
@@ -3554,7 +3558,7 @@ class DynamicsWorldModel(Module):
             discrete_mask = rearrange(discrete_mask, 'b t mtp -> mtp b t')

         if exists(continuous_actions):
-            continuous_action_targets, continuous_mask = create_multi_token_prediction_targets(
+            continuous_action_targets, continuous_mask = create_multi_token_prediction_targets(continuous_actions, self.multi_token_pred_len)
             continuous_action_targets = rearrange(continuous_action_targets, 'b t mtp ... -> mtp b t ...')
             continuous_mask = rearrange(continuous_mask, 'b t mtp -> mtp b t')

{dreamer4-0.1.0.dist-info → dreamer4-0.1.4.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dreamer4
-Version: 0.1.0
+Version: 0.1.4
 Summary: Dreamer 4
 Project-URL: Homepage, https://pypi.org/project/dreamer4/
 Project-URL: Repository, https://github.com/lucidrains/dreamer4
@@ -57,10 +57,16 @@ Description-Content-Type: text/markdown

 Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v1) for his [Dreamer](https://danijar.com/project/dreamer4/) line of work

+[Discord channel](https://discord.gg/ab4BEk3W) for collaborating with other researchers interested in this work
+
+## Appreciation
+
+- [@dirkmcpherson](https://github.com/dirkmcpherson) for fixes to typo errors and unpassed arguments!
+
 ## Install

 ```bash
-$ pip install dreamer4
+$ pip install dreamer4
 ```

 ## Usage
@@ -79,9 +85,16 @@ tokenizer = VideoTokenizer(
     image_width = 256
 )

+video = torch.randn(2, 3, 10, 256, 256)
+
+# learn the tokenizer
+
+loss = tokenizer(video)
+loss.backward() # ler
+
 # dynamics world model

-dynamics = DynamicsWorldModel(
+world_model = DynamicsWorldModel(
     dim = 512,
     dim_latent = 32,
     video_tokenizer = tokenizer,
@@ -97,7 +110,7 @@ rewards = torch.randn(2, 10)

 # learn dynamics / behavior cloned model

-loss = dynamics(
+loss = world_model(
     video = video,
     rewards = rewards,
     discrete_actions = discrete_actions
@@ -109,7 +122,7 @@ loss.backward()

 # then generate dreams

-dreams = dynamics.generate(
+dreams = world_model.generate(
     10,
     batch_size = 2,
     return_decoded_video = True,
@@ -118,7 +131,19 @@ dreams = dynamics.generate(

 # learn from the dreams

-actor_loss, critic_loss = dynamics.learn_from_experience(dreams)
+actor_loss, critic_loss = world_model.learn_from_experience(dreams)
+
+(actor_loss + critic_loss).backward()
+
+# learn from environment
+
+from dreamer4.mocks import MockEnv
+
+mock_env = MockEnv((256, 256), vectorized = True, num_envs = 4)
+
+experience = world_model.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = True)
+
+actor_loss, critic_loss = world_model.learn_from_experience(experience)

 (actor_loss + critic_loss).backward()
 ```
@@ -137,4 +162,4 @@ actor_loss, critic_loss = dynamics.learn_from_experience(dreams)
 }
 ```

-*the conquest of nature is to be achieved through number and measure
+*the conquest of nature is to be achieved through number and measure - angels to Descartes in a dream*
dreamer4-0.1.4.dist-info/RECORD
ADDED
@@ -0,0 +1,8 @@
+dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
+dreamer4/dreamer4.py,sha256=ghestMgz7B1oEqBRR0XkkdWe0kkh7bshhzmi6-n-XIs,120790
+dreamer4/mocks.py,sha256=TfqOB_Gq6N_GggBYwa6ZAJQx38ntlYbXZe23Ne4jshw,2502
+dreamer4/trainers.py,sha256=JsnJwQJcbI_75KBTNddG6b7QVkO6LD1N_HQiVe-VnCM,15087
+dreamer4-0.1.4.dist-info/METADATA,sha256=GkzuqKtNJJCSh5FycWJOr49253_w926biJkSz9ic4TQ,4941
+dreamer4-0.1.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+dreamer4-0.1.4.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+dreamer4-0.1.4.dist-info/RECORD,,
dreamer4-0.1.0.dist-info/RECORD
DELETED
@@ -1,8 +0,0 @@
-dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
-dreamer4/dreamer4.py,sha256=YB724hMjBYDNhApo2x_52oXIeH5GGQo8Q2pB2lkCq_s,120297
-dreamer4/mocks.py,sha256=TfqOB_Gq6N_GggBYwa6ZAJQx38ntlYbXZe23Ne4jshw,2502
-dreamer4/trainers.py,sha256=JsnJwQJcbI_75KBTNddG6b7QVkO6LD1N_HQiVe-VnCM,15087
-dreamer4-0.1.0.dist-info/METADATA,sha256=kDq66Il_WDNKR2NP9wrVu3fMUIVW-pWySq3CP2ANZ2s,4273
-dreamer4-0.1.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-dreamer4-0.1.0.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-dreamer4-0.1.0.dist-info/RECORD,,

{dreamer4-0.1.0.dist-info → dreamer4-0.1.4.dist-info}/WHEEL
File without changes

{dreamer4-0.1.0.dist-info → dreamer4-0.1.4.dist-info}/licenses/LICENSE
File without changes