dreamer4 0.0.71.tar.gz → 0.0.73.tar.gz
- {dreamer4-0.0.71 → dreamer4-0.0.73}/PKG-INFO +1 -1
- {dreamer4-0.0.71 → dreamer4-0.0.73}/dreamer4/dreamer4.py +42 -4
- {dreamer4-0.0.71 → dreamer4-0.0.73}/pyproject.toml +1 -1
- {dreamer4-0.0.71 → dreamer4-0.0.73}/tests/test_dreamer.py +14 -4
- {dreamer4-0.0.71 → dreamer4-0.0.73}/.github/workflows/python-publish.yml +0 -0
- {dreamer4-0.0.71 → dreamer4-0.0.73}/.github/workflows/test.yml +0 -0
- {dreamer4-0.0.71 → dreamer4-0.0.73}/.gitignore +0 -0
- {dreamer4-0.0.71 → dreamer4-0.0.73}/LICENSE +0 -0
- {dreamer4-0.0.71 → dreamer4-0.0.73}/README.md +0 -0
- {dreamer4-0.0.71 → dreamer4-0.0.73}/dreamer4/__init__.py +0 -0
- {dreamer4-0.0.71 → dreamer4-0.0.73}/dreamer4/mocks.py +0 -0
- {dreamer4-0.0.71 → dreamer4-0.0.73}/dreamer4/trainers.py +0 -0
- {dreamer4-0.0.71 → dreamer4-0.0.73}/dreamer4-fig2.png +0 -0
{dreamer4-0.0.71 → dreamer4-0.0.73}/dreamer4/dreamer4.py

@@ -82,6 +82,7 @@ class Experience:
     log_probs: tuple[Tensor, Tensor] | None = None
     values: Tensor | None = None
     step_size: int | None = None
+    lens: Tensor | None = None,
     agent_index: int = 0
     is_from_world_model: bool = True

@@ -155,6 +156,15 @@ def is_power_two(num):
 def is_empty(t):
     return t.numel() == 0

+def lens_to_mask(t, max_len = None):
+    if not exists(max_len):
+        max_len = t.amax().item()
+
+    device = t.device
+    seq = torch.arange(max_len, device = device)
+
+    return einx.less('j, i -> i j', seq, t)
+
 def log(t, eps = 1e-20):
     return t.clamp(min = eps).log()

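The new `lens_to_mask` helper converts per-sample sequence lengths into a boolean mask over timesteps, with `True` marking valid positions. A minimal sketch of the same computation in plain PyTorch, without the `einx` dependency (the name `lens_to_mask_sketch` and the example values are illustrative only):

import torch

def lens_to_mask_sketch(lens, max_len = None):
    # lens: (batch,) integer tensor of valid sequence lengths
    if max_len is None:
        max_len = int(lens.amax().item())

    seq = torch.arange(max_len, device = lens.device)    # (max_len,)
    # same result as einx.less('j, i -> i j', seq, lens)
    return seq[None, :] < lens[:, None]                  # (batch, max_len) bool

print(lens_to_mask_sketch(torch.tensor([2, 4]), max_len = 4))
# tensor([[ True,  True, False, False],
#         [ True,  True,  True,  True]])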
@@ -2551,12 +2561,16 @@ class DynamicsWorldModel(Module):

         # returning agent actions, rewards, and log probs + values for policy optimization

+        batch, device = latents.shape[0], latents.device
+        experience_lens = torch.full((batch,), time_steps, device = device)
+
         gen = Experience(
             latents = latents,
             video = video,
             proprio = proprio if has_proprio else None,
             step_size = step_size,
             agent_index = agent_index,
+            lens = experience_lens,
             is_from_world_model = True
         )

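Rollouts generated by the world model always run for the full number of steps, so the lengths recorded on the returned `Experience` are just a constant tensor. For illustration (batch size and step count are made-up values):

import torch

batch, time_steps = 2, 6
experience_lens = torch.full((batch,), time_steps)
print(experience_lens)   # tensor([6, 6]) - every generated trajectory is fully valid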
@@ -2581,6 +2595,7 @@ class DynamicsWorldModel(Module):
         *,
         video = None, # (b v? c t vh vw)
         latents = None, # (b t v? n d) | (b t v? d)
+        lens = None, # (b)
         signal_levels = None, # () | (b) | (b t)
         step_sizes = None, # () | (b)
         step_sizes_log2 = None, # () | (b)
@@ -3014,7 +3029,19 @@ class DynamicsWorldModel(Module):

         flow_losses = flow_losses * loss_weight

-        flow_loss = flow_losses.mean()
+        # handle variable lengths if needed
+
+        is_var_len = exists(lens)
+
+        if is_var_len:
+
+            loss_mask = lens_to_mask(lens, time)
+            loss_mask_without_last = loss_mask[:, :-1]
+
+            flow_loss = flow_losses[loss_mask].mean()
+
+        else:
+            flow_loss = flow_losses.mean()

         # now take care of the agent token losses

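When `lens` is given, the flow loss is averaged only over the timesteps each sample actually contains, rather than over the whole padded tensor. A self-contained sketch of that masked mean (shapes and values are invented for illustration; the real `flow_losses` carries more dimensions):

import torch

batch, time = 2, 5
flow_losses = torch.rand(batch, time)        # per-timestep losses, padding positions included
lens = torch.tensor([3, 5])                  # valid lengths per sample

seq = torch.arange(time)
loss_mask = seq[None, :] < lens[:, None]     # (batch, time) bool, True on valid steps

masked_loss = flow_losses[loss_mask].mean()  # mean over the 8 valid positions only
plain_loss = flow_losses.mean()              # would also average the 2 padded positions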
@@ -3037,7 +3064,10 @@ class DynamicsWorldModel(Module):

         reward_losses = reward_losses.masked_fill(reward_loss_mask, 0.)

-        reward_loss = reduce(reward_losses, '... mtp -> mtp', 'mean') # they sum across the prediction steps (mtp dimension) - eq(9)
+        if is_var_len:
+            reward_loss = reward_losses[loss_mask_without_last].mean(dim = 0)
+        else:
+            reward_loss = reduce(reward_losses, '... mtp -> mtp', 'mean') # they sum across the prediction steps (mtp dimension) - eq(9)

         # maybe autoregressive action loss

@@ -3080,12 +3110,20 @@ class DynamicsWorldModel(Module):
         if exists(discrete_log_probs):
             discrete_log_probs = discrete_log_probs.masked_fill(~discrete_mask[..., None], 0.)

-            discrete_action_loss = reduce(-discrete_log_probs, 'mtp b t na -> mtp', 'mean')
+            if is_var_len:
+                discrete_action_losses = rearrange(-discrete_log_probs, 'mtp b t na -> b t na mtp')
+                discrete_action_loss = reduce(discrete_action_losses[loss_mask_without_last], '... mtp -> mtp', 'mean')
+            else:
+                discrete_action_loss = reduce(-discrete_log_probs, 'mtp b t na -> mtp', 'mean')

         if exists(continuous_log_probs):
             continuous_log_probs = continuous_log_probs.masked_fill(~continuous_mask[..., None], 0.)

-            continuous_action_loss = reduce(-continuous_log_probs, 'mtp b t na -> mtp', 'mean')
+            if is_var_len:
+                continuous_action_losses = rearrange(-continuous_log_probs, 'mtp b t na -> b t na mtp')
+                continuous_action_loss = reduce(continuous_action_losses[loss_mask_without_last], '... mtp -> mtp', 'mean')
+            else:
+                continuous_action_loss = reduce(-continuous_log_probs, 'mtp b t na -> mtp', 'mean')

         # handle loss normalization

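The reward and action losses keep the multi-token-prediction (`mtp`) dimension, so under variable lengths the tensors are first rearranged to put `mtp` last, boolean-indexed with the (batch, time) mask, and then reduced over everything except `mtp`. A sketch of that pattern with invented sizes (the mask here stands in for `loss_mask_without_last`):

import torch
from einops import rearrange, reduce

mtp, batch, time, num_actions = 4, 2, 5, 3
log_probs = torch.randn(mtp, batch, time, num_actions)
loss_mask = torch.arange(time)[None, :] < torch.tensor([3, 5])[:, None]    # (batch, time)

losses = rearrange(-log_probs, 'mtp b t na -> b t na mtp')
masked = losses[loss_mask]                                  # (num_valid, na, mtp)
per_step_loss = reduce(masked, '... mtp -> mtp', 'mean')    # one loss per prediction step

# without variable lengths, the reduction stays as before
full_loss = reduce(-log_probs, 'mtp b t na -> mtp', 'mean')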
{dreamer4-0.0.71 → dreamer4-0.0.73}/tests/test_dreamer.py

@@ -16,6 +16,7 @@ def exists(v):
 @param('num_residual_streams', (1, 4))
 @param('add_reward_embed_to_agent_token', (False, True))
 @param('use_time_kv_cache', (False, True))
+@param('var_len', (False, True))
 def test_e2e(
     pred_orig_latent,
     grouped_query_attn,
@@ -27,7 +28,8 @@ def test_e2e(
     condition_on_actions,
     num_residual_streams,
     add_reward_embed_to_agent_token,
-    use_time_kv_cache
+    use_time_kv_cache,
+    var_len
 ):
     from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel

@@ -95,8 +97,13 @@ def test_e2e(
     if condition_on_actions:
         actions = torch.randint(0, 4, (2, 3, 1))

+    lens = None
+    if var_len:
+        lens = torch.randint(1, 4, (2,))
+
     flow_loss = dynamics(
         **dynamics_input,
+        lens = lens,
         tasks = tasks,
         signal_levels = signal_levels,
         step_sizes_log2 = step_sizes_log2,
@@ -702,10 +709,11 @@ def test_proprioception(
     )

     if num_video_views > 1:
-        video = torch.randn(2, num_video_views, 3, 10, 256, 256)
+        video_shape = (2, num_video_views, 3, 10, 256, 256)
     else:
-        video = torch.randn(2, 3, 10, 256, 256)
+        video_shape = (2, 3, 10, 256, 256)

+    video = torch.randn(*video_shape)
     rewards = torch.randn(2, 10)
     proprio = torch.randn(2, 10, 21)
     discrete_actions = torch.randint(0, 4, (2, 10, 1))
@@ -722,8 +730,10 @@ def test_proprioception(
     loss.backward()

     generations = dynamics.generate(
-
+        10,
         batch_size = 2,
+        return_decoded_video = True
     )

     assert exists(generations.proprio)
+    assert generations.video.shape == video_shape