dreamer4 0.0.102__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- dreamer4/dreamer4.py +47 -15
- {dreamer4-0.0.102.dist-info → dreamer4-0.1.1.dist-info}/METADATA +69 -3
- dreamer4-0.1.1.dist-info/RECORD +8 -0
- dreamer4-0.0.102.dist-info/RECORD +0 -8
- {dreamer4-0.0.102.dist-info → dreamer4-0.1.1.dist-info}/WHEEL +0 -0
- {dreamer4-0.0.102.dist-info → dreamer4-0.1.1.dist-info}/licenses/LICENSE +0 -0
dreamer4/dreamer4.py CHANGED

@@ -1331,6 +1331,12 @@ class Attention(Module):
         q = self.q_heads_rmsnorm(q)
         k = self.k_heads_rmsnorm(k)
 
+        # rotary
+
+        if exists(rotary_pos_emb):
+            q = apply_rotations(rotary_pos_emb, q)
+            k = apply_rotations(rotary_pos_emb, k)
+
         # caching
 
         if exists(kv_cache):
@@ -1338,12 +1344,6 @@ class Attention(Module):
             k = cat((ck, k), dim = -2)
             v = cat((cv, v), dim = -2)
 
-        # rotary
-
-        if exists(rotary_pos_emb):
-            q = apply_rotations(rotary_pos_emb, q)
-            k = apply_rotations(rotary_pos_emb, k)
-
         # attention
 
         attend_fn = default(attend_fn, naive_attend)
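The two hunks above move the rotary embedding application so it happens before keys are appended to the kv cache rather than after. A minimal sketch of why the order matters, using toy RoPE helpers (not the library's `apply_rotations`): cached keys already carry their rotations, so rotating the concatenated tensor would rotate them a second time.

```python
import torch

def rotate_half(x):
    x1, x2 = x.chunk(2, dim = -1)
    return torch.cat((-x2, x1), dim = -1)

def apply_rotary(pos, x, dim_head = 8):
    # standard RoPE: x * cos(theta) + rotate_half(x) * sin(theta)
    inv_freq = 1. / (10000 ** (torch.arange(0, dim_head, 2).float() / dim_head))
    freqs = torch.einsum('i,j->ij', pos.float(), inv_freq)
    freqs = torch.cat((freqs, freqs), dim = -1)
    return x * freqs.cos() + rotate_half(x) * freqs.sin()

dim_head = 8
cached_k = torch.randn(4, dim_head)  # keys for positions 0-3, rotated when first cached
new_k = torch.randn(1, dim_head)     # incoming key for position 4, not yet rotated

# correct (post-fix order): rotate only the new key at its absolute position, then append

k_correct = torch.cat((cached_k, apply_rotary(torch.tensor([4]), new_k)), dim = 0)

# incorrect (pre-fix order): appending first and then rotating the whole tensor
# re-rotates the already-rotated cached keys

k_wrong = apply_rotary(torch.arange(5), torch.cat((cached_k, new_k), dim = 0))

assert not torch.allclose(k_correct, k_wrong)
```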
@@ -1507,12 +1507,11 @@ class AxialSpaceTimeTransformer(Module):
 
         has_kv_cache = exists(kv_cache)
 
-
         if has_kv_cache:
             past_tokens, tokens = tokens[:, :-1], tokens[:, -1:]
 
             rotary_seq_len = 1
-            rotary_pos_offset = past_tokens.shape[
+            rotary_pos_offset = past_tokens.shape[1]
         else:
             rotary_seq_len = time
             rotary_pos_offset = 0
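When decoding a single new token against a cache, that token's rotary position must be offset by the length of the cached prefix along the time axis, which is what the corrected assignment computes. A short illustration of the offset (toy shapes, not the library's code):

```python
import torch

past_tokens = torch.randn(2, 7, 512)  # (batch, time, dim) tokens already processed

rotary_seq_len = 1                     # only the newest token needs a rotation
rotary_pos_offset = past_tokens.shape[1]

positions = torch.arange(rotary_seq_len) + rotary_pos_offset
print(positions)  # tensor([7]) - the new token sits at absolute position 7
```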
@@ -1687,6 +1686,7 @@ class VideoTokenizer(Module):
             time_block_every = time_block_every,
             num_special_spatial_tokens = num_latent_tokens,
             num_residual_streams = num_residual_streams,
+            special_attend_only_itself = True,
             final_norm = True
         )
 
@@ -2429,6 +2429,7 @@ class DynamicsWorldModel(Module):
         normalize_advantages = None,
         eps = 1e-6
     ):
+        assert isinstance(experience, Experience)
 
         latents = experience.latents
         actions = experience.actions
@@ -2441,7 +2442,7 @@ class DynamicsWorldModel(Module):
         step_size = experience.step_size
         agent_index = experience.agent_index
 
-        assert all([*map(exists, (old_log_probs, actions, old_values, rewards, step_size))]), 'the generations need to contain the log probs, values, and rewards for policy optimization'
+        assert all([*map(exists, (old_log_probs, actions, old_values, rewards, step_size))]), 'the generations need to contain the log probs, values, and rewards for policy optimization - world_model.generate(..., return_log_probs_and_values = True)'
 
         batch, time = latents.shape[0], latents.shape[1]
 
@@ -2455,8 +2456,8 @@ class DynamicsWorldModel(Module):
         if exists(experience.lens):
             mask_for_gae = lens_to_mask(experience.lens, time)
 
-            rewards = rewards.masked_fill(mask_for_gae, 0.)
-            old_values = old_values.masked_fill(mask_for_gae, 0.)
+            rewards = rewards.masked_fill(~mask_for_gae, 0.)
+            old_values = old_values.masked_fill(~mask_for_gae, 0.)
 
         # calculate returns
 
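The sign flip here (and the matching one on the reward losses in a later hunk) follows the usual `lens_to_mask` convention, in which the mask is True at valid positions; zeroing must therefore happen where the mask is False. A small self-contained demonstration, with `lens_to_mask` reimplemented locally for the sketch:

```python
import torch

def lens_to_mask(lens, total_len):
    # True for positions within each sequence's length
    return torch.arange(total_len) < lens[:, None]

lens = torch.tensor([2, 4])
rewards = torch.ones(2, 4)

mask = lens_to_mask(lens, 4)
# tensor([[ True,  True, False, False],
#         [ True,  True,  True,  True]])

zeroed = rewards.masked_fill(~mask, 0.)  # fixed: padding positions zeroed
# tensor([[1., 1., 0., 0.],
#         [1., 1., 1., 1.]])

wrong = rewards.masked_fill(mask, 0.)    # pre-fix: valid steps zeroed instead
```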
@@ -2491,7 +2492,7 @@ class DynamicsWorldModel(Module):
 
         # mean, var - todo - handle distributed
 
-        returns_mean, returns_var =
+        returns_mean, returns_var = returns_for_stats.mean(), returns_for_stats.var()
 
         # ema
 
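The completed line computes plain batch statistics of the returns; judging by the `# ema` comment that follows, those are then folded into running values. A hedged sketch of that pattern, where `ema_decay` and the running buffers are illustrative names rather than the library's:

```python
import torch

ema_decay = 0.99
running_mean, running_var = torch.zeros(()), torch.ones(())

returns_for_stats = torch.randn(64)
returns_mean, returns_var = returns_for_stats.mean(), returns_for_stats.var()

# lerp(new, 1 - decay) keeps 99% of the old statistic and mixes in 1% of the new
running_mean = running_mean.lerp(returns_mean, 1. - ema_decay)
running_var = running_var.lerp(returns_var, 1. - ema_decay)
```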
@@ -2694,12 +2695,22 @@ class DynamicsWorldModel(Module):
         return_rewards_per_frame = False,
         return_agent_actions = False,
         return_log_probs_and_values = False,
+        return_for_policy_optimization = False,
         return_time_kv_cache = False,
         store_agent_embed = True,
         store_old_action_unembeds = True
 
     ): # (b t n d) | (b c t h w)
 
+        # handy flag for returning generations for rl
+
+        if return_for_policy_optimization:
+            return_agent_actions |= True
+            return_log_probs_and_values |= True
+            return_rewards_per_frame |= True
+
+        # more variables
+
         has_proprio = self.has_proprio
         was_training = self.training
         self.eval()
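Since the new flag only switches on the three existing outputs that `learn_from_experience` asserts on, its folding logic can be shown standalone. This is the hunk's logic lifted into a toy function (the function wrapper itself is just for the demo):

```python
def resolve_generate_flags(
    return_rewards_per_frame = False,
    return_agent_actions = False,
    return_log_probs_and_values = False,
    return_for_policy_optimization = False
):
    # one flag implies all three outputs needed for policy optimization
    if return_for_policy_optimization:
        return_agent_actions |= True
        return_log_probs_and_values |= True
        return_rewards_per_frame |= True

    return return_rewards_per_frame, return_agent_actions, return_log_probs_and_values

assert resolve_generate_flags(return_for_policy_optimization = True) == (True, True, True)
assert resolve_generate_flags() == (False, False, False)
```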
@@ -2769,6 +2780,19 @@ class DynamicsWorldModel(Module):
 
         curr_time_steps = latents.shape[1]
 
+        # determine whether to take an extra step if
+        # (1) using time kv cache
+        # (2) decoding anything off agent embedding (rewards, actions, etc)
+
+        take_extra_step = (
+            use_time_kv_cache or
+            return_rewards_per_frame or
+            store_agent_embed or
+            return_agent_actions
+        )
+
+        # prepare noised latent / proprio inputs
+
         noised_latent = randn((batch_size, 1, self.num_video_views, *latent_shape), device = self.device)
 
         noised_proprio = None
@@ -2776,7 +2800,10 @@ class DynamicsWorldModel(Module):
         if has_proprio:
             noised_proprio = randn((batch_size, 1, self.dim_proprio), device = self.device)
 
-        for step in range(num_steps):
+        # denoising steps
+
+        for step in range(num_steps + int(take_extra_step)):
+
             is_last_step = (step + 1) == num_steps
 
             signal_levels = full((batch_size, 1), step * step_size, dtype = torch.long, device = self.device)
@@ -2819,6 +2846,11 @@ class DynamicsWorldModel(Module):
             if use_time_kv_cache and is_last_step:
                 time_kv_cache = next_time_kv_cache
 
+            # early break if taking an extra step for agent embedding off cleaned latents for decoding
+
+            if take_extra_step and is_last_step:
+                break
+
             # maybe proprio
 
             if has_proprio:
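Per the added comments, the extra iteration exists so one final forward pass can run on the fully denoised latents, purely to read out the agent embedding (and fill the time kv cache) before the loop exits without applying another denoising update. An independent toy of that idea, not the library's actual loop body, with the exit condition simplified:

```python
def forward_pass(latents, signal_level):
    # stand-in for the transformer call: returns (denoised latents, agent embedding)
    denoised = [l * 0.9 for l in latents]
    return denoised, sum(latents)

num_steps = 4
take_extra_step = True  # kv cache in use, or decoding off the agent embedding

latents = [1.0, 2.0]
agent_embed = None

for step in range(num_steps + int(take_extra_step)):
    update, agent_embed = forward_pass(latents, signal_level = step / num_steps)

    if take_extra_step and step == num_steps:
        # the clean latents were only passed through to read the embedding -
        # do not apply another denoising update
        break

    latents = update
```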
@@ -3021,7 +3053,7 @@ class DynamicsWorldModel(Module):
         latent_is_noised = False,
         return_all_losses = False,
         return_intermediates = False,
-        add_autoregressive_action_loss =
+        add_autoregressive_action_loss = True,
         update_loss_ema = None,
         latent_has_view_dim = False
     ):
@@ -3478,7 +3510,7 @@ class DynamicsWorldModel(Module):
 
         reward_losses = F.cross_entropy(reward_pred, reward_targets, reduction = 'none')
 
-        reward_losses = reward_losses.masked_fill(reward_loss_mask, 0.)
+        reward_losses = reward_losses.masked_fill(~reward_loss_mask, 0.)
 
         if is_var_len:
             reward_loss = reward_losses[loss_mask_without_last].mean(dim = 0)
{dreamer4-0.0.102.dist-info → dreamer4-0.1.1.dist-info}/METADATA CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dreamer4
-Version: 0.0.102
+Version: 0.1.1
 Summary: Dreamer 4
 Project-URL: Homepage, https://pypi.org/project/dreamer4/
 Project-URL: Repository, https://github.com/lucidrains/dreamer4
@@ -53,11 +53,75 @@ Description-Content-Type: text/markdown
 
 <img src="./dreamer4-fig2.png" width="400px"></img>
 
-## Dreamer 4
+## Dreamer 4
 
 Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v1) for his [Dreamer](https://danijar.com/project/dreamer4/) line of work
 
-
+## Install
+
+```bash
+$ pip install dreamer4
+```
+
+## Usage
+
+```python
+import torch
+from dreamer4 import VideoTokenizer, DynamicsWorldModel
+
+# video tokenizer, learned through MAE + lpips
+
+tokenizer = VideoTokenizer(
+    dim = 512,
+    dim_latent = 32,
+    patch_size = 32,
+    image_height = 256,
+    image_width = 256
+)
+
+# dynamics world model
+
+dynamics = DynamicsWorldModel(
+    dim = 512,
+    dim_latent = 32,
+    video_tokenizer = tokenizer,
+    num_discrete_actions = 4,
+    num_residual_streams = 1
+)
+
+# state, action, rewards
+
+video = torch.randn(2, 3, 10, 256, 256)
+discrete_actions = torch.randint(0, 4, (2, 10, 1))
+rewards = torch.randn(2, 10)
+
+# learn dynamics / behavior cloned model
+
+loss = dynamics(
+    video = video,
+    rewards = rewards,
+    discrete_actions = discrete_actions
+)
+
+loss.backward()
+
+# do the above with much data
+
+# then generate dreams
+
+dreams = dynamics.generate(
+    10,
+    batch_size = 2,
+    return_decoded_video = True,
+    return_for_policy_optimization = True
+)
+
+# learn from the dreams
+
+actor_loss, critic_loss = dynamics.learn_from_experience(dreams)
+
+(actor_loss + critic_loss).backward()
+```
 
 ## Citation
 
@@ -72,3 +136,5 @@ Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v
     url = {https://arxiv.org/abs/2509.24527},
 }
 ```
+
+*the conquest of nature is to be achieved through number and measure - angels to Descartes in a dream*
dreamer4-0.1.1.dist-info/RECORD ADDED

@@ -0,0 +1,8 @@
+dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
+dreamer4/dreamer4.py,sha256=_gWp08k7tf2VCUv7uqkXKZQugnqJqXPb1-o7_34SA9c,120365
+dreamer4/mocks.py,sha256=TfqOB_Gq6N_GggBYwa6ZAJQx38ntlYbXZe23Ne4jshw,2502
+dreamer4/trainers.py,sha256=JsnJwQJcbI_75KBTNddG6b7QVkO6LD1N_HQiVe-VnCM,15087
+dreamer4-0.1.1.dist-info/METADATA,sha256=2zkBv1BHvGpb6onAnEFsKnPK2KD-0vH8K1nFDBVlpyU,4247
+dreamer4-0.1.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+dreamer4-0.1.1.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+dreamer4-0.1.1.dist-info/RECORD,,
dreamer4-0.0.102.dist-info/RECORD DELETED

@@ -1,8 +0,0 @@
-dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
-dreamer4/dreamer4.py,sha256=3qeVN3qdvx7iPxA0OBXw_yy5Re6rX6FIKITH9bp6RBs,119202
-dreamer4/mocks.py,sha256=TfqOB_Gq6N_GggBYwa6ZAJQx38ntlYbXZe23Ne4jshw,2502
-dreamer4/trainers.py,sha256=JsnJwQJcbI_75KBTNddG6b7QVkO6LD1N_HQiVe-VnCM,15087
-dreamer4-0.0.102.dist-info/METADATA,sha256=xxVL1sFimb0azSD5sDOEzugY7rBT6oDek4YdiIS8m18,3066
-dreamer4-0.0.102.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-dreamer4-0.0.102.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-dreamer4-0.0.102.dist-info/RECORD,,
{dreamer4-0.0.102.dist-info → dreamer4-0.1.1.dist-info}/WHEEL: file without changes

{dreamer4-0.0.102.dist-info → dreamer4-0.1.1.dist-info}/licenses/LICENSE: file without changes