dreamer4 0.0.70.tar.gz → 0.0.72.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dreamer4 might be problematic.
- {dreamer4-0.0.70 → dreamer4-0.0.72}/PKG-INFO +1 -1
- {dreamer4-0.0.70 → dreamer4-0.0.72}/dreamer4/dreamer4.py +98 -20
- {dreamer4-0.0.70 → dreamer4-0.0.72}/pyproject.toml +1 -1
- {dreamer4-0.0.70 → dreamer4-0.0.72}/tests/test_dreamer.py +22 -4
- {dreamer4-0.0.70 → dreamer4-0.0.72}/.github/workflows/python-publish.yml +0 -0
- {dreamer4-0.0.70 → dreamer4-0.0.72}/.github/workflows/test.yml +0 -0
- {dreamer4-0.0.70 → dreamer4-0.0.72}/.gitignore +0 -0
- {dreamer4-0.0.70 → dreamer4-0.0.72}/LICENSE +0 -0
- {dreamer4-0.0.70 → dreamer4-0.0.72}/README.md +0 -0
- {dreamer4-0.0.70 → dreamer4-0.0.72}/dreamer4/__init__.py +0 -0
- {dreamer4-0.0.70 → dreamer4-0.0.72}/dreamer4/mocks.py +0 -0
- {dreamer4-0.0.70 → dreamer4-0.0.72}/dreamer4/trainers.py +0 -0
- {dreamer4-0.0.70 → dreamer4-0.0.72}/dreamer4-fig2.png +0 -0
dreamer4/dreamer4.py

@@ -45,6 +45,7 @@ from assoc_scan import AssocScan
 # vc - video channels
 # vh, vw - video height and width
 # mtp - multi token prediction length
+# v - video viewpoints
 
 import einx
 from einx import add, multiply
@@ -154,6 +155,15 @@ def is_power_two(num):
 def is_empty(t):
     return t.numel() == 0
 
+def lens_to_mask(t, max_len = None):
+    if not exists(max_len):
+        max_len = t.amax().item()
+
+    device = t.device
+    seq = torch.arange(max_len, device = device)
+
+    return einx.less('j, i -> i j', seq, t)
+
 def log(t, eps = 1e-20):
     return t.clamp(min = eps).log()
 
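The new lens_to_mask helper converts per-sample sequence lengths into a boolean mask, later used to mask padded frames out of the losses. A minimal sketch of its behavior (assuming torch and einx are installed, and that exists(v) is the usual "v is not None" helper from this file):

import torch
import einx

def exists(v):
    return v is not None

def lens_to_mask(t, max_len = None):
    # default the mask width to the longest sequence in the batch
    if not exists(max_len):
        max_len = t.amax().item()

    seq = torch.arange(max_len, device = t.device)

    # mask[i, j] is True while position j lies inside sequence i
    return einx.less('j, i -> i j', seq, t)

lens = torch.tensor([2, 4])
print(lens_to_mask(lens))
# tensor([[ True,  True, False, False],
#         [ True,  True,  True,  True]])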
@@ -1735,6 +1745,7 @@ class DynamicsWorldModel(Module):
         num_latent_tokens = None,
         num_agents = 1,
         num_tasks = 0,
+        num_video_views = 1,
         dim_proprio = None,
         reward_encoder_kwargs: dict = dict(),
         depth = 4,
@@ -1800,7 +1811,7 @@ class DynamicsWorldModel(Module):
         )
 
         self.to_latent_pred = Sequential(
-            Reduce('b t n s d -> b t n d', 'mean'),
+            Reduce('b t v n s d -> b t v n d', 'mean'),
             RMSNorm(dim),
             LinearNoBias(dim, dim_latent)
         )
@@ -1810,17 +1821,27 @@ class DynamicsWorldModel(Module):
         latent_tokens_to_space = num_latent_tokens // num_spatial_tokens
 
         self.latents_to_spatial_tokens = Sequential(
-            Rearrange('b t n d -> b t (n d)'),
+            Rearrange('... n d -> ... (n d)'),
             Linear(num_latent_tokens * dim_latent, dim * num_spatial_tokens),
-            Rearrange('b t (s d) -> b t s d', s = num_spatial_tokens)
+            Rearrange('... (s d) -> ... s d', s = num_spatial_tokens)
         )
 
         self.to_latent_pred = Sequential(
             RMSNorm(dim),
             LinearNoBias(dim, dim_latent * latent_tokens_to_space),
-            Rearrange('b t s (n d) -> b t (s n) d', n = latent_tokens_to_space)
+            Rearrange('b t v s (n d) -> b t v (s n) d', n = latent_tokens_to_space)
         )
 
+        # number of video views, for robotics, which could have third person + wrist camera at least
+
+        assert num_video_views >= 1
+        self.video_has_multi_view = num_video_views > 1
+
+        self.num_video_views = num_video_views
+
+        if self.video_has_multi_view:
+            self.view_emb = nn.Parameter(torch.randn(num_video_views, dim) * 1e-2)
+
         # proprioception
 
         self.has_proprio = exists(dim_proprio)
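The new view_emb parameter gives each camera view its own learned embedding. The diff does not show where it is added in, but a plausible broadcast over the per-view spatial tokens, using the einx.add already imported at the top of the file, might look like the sketch below (shapes follow the file's dimension-name comments; the application site is an assumption):

import torch
from einx import add

b, t, v, s, d = 2, 3, 2, 8, 16   # batch, time, views, spatial tokens, model dim

space_tokens = torch.randn(b, t, v, s, d)
view_emb = torch.randn(v, d) * 1e-2   # mirrors nn.Parameter(torch.randn(num_video_views, dim) * 1e-2)

# hypothetical: broadcast each view's embedding onto all of its spatial tokens
space_tokens = add('b t v s d, v d -> b t v s d', space_tokens, view_emb)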
@@ -2318,7 +2339,7 @@ class DynamicsWorldModel(Module):
         # denoising
         # teacher forcing to start with
 
-        latents = empty((batch_size, 0, *latent_shape), device = self.device)
+        latents = empty((batch_size, 0, self.num_video_views, *latent_shape), device = self.device)
 
         past_latents_context_noise = latents.clone()
 
@@ -2354,7 +2375,7 @@ class DynamicsWorldModel(Module):
 
         curr_time_steps = latents.shape[1]
 
-        noised_latent = randn((batch_size, 1, *latent_shape), device = self.device)
+        noised_latent = randn((batch_size, 1, self.num_video_views, *latent_shape), device = self.device)
 
         noised_proprio = None
 
@@ -2365,12 +2386,12 @@ class DynamicsWorldModel(Module):
             is_last_step = (step + 1) == num_steps
 
             signal_levels = full((batch_size, 1), step * step_size, dtype = torch.long, device = self.device)
-
+
             # noising past latent context
 
             noised_context = latents.lerp(past_latents_context_noise, context_signal_noise) # the paragraph after eq (8)
 
-            noised_latent_with_context, pack_context_shape = pack((noised_context, noised_latent), 'b * n d')
+            noised_latent_with_context, pack_context_shape = pack((noised_context, noised_latent), 'b * v n d')
 
             # handle proprio
 
@@ -2395,6 +2416,7 @@ class DynamicsWorldModel(Module):
                 proprio = noised_proprio_with_context,
                 time_kv_cache = time_kv_cache,
                 latent_is_noised = True,
+                latent_has_view_dim = True,
                 return_pred_only = True,
                 return_intermediates = True,
             )
@@ -2409,12 +2431,11 @@ class DynamicsWorldModel(Module):
 
             # unpack pred
 
-            _, pred = unpack(pred, pack_context_shape, 'b * n d')
+            _, pred = unpack(pred, pack_context_shape, 'b * v n d')
 
             if has_proprio:
                 _, pred_proprio = unpack(pred_proprio, pack_context_shape, 'b * d')
 
-
             # derive flow, based on whether in x-space or not
 
             def denoise_step(pred, noised, signal_levels):
@@ -2507,12 +2528,26 @@ class DynamicsWorldModel(Module):
         video = None
 
         if return_decoded_video:
+
+            latents_for_video = rearrange(latents, 'b t v n d -> b v t n d')
+            latents_for_video, unpack_view = pack_one(latents_for_video, '* t n d')
+
             video = self.video_tokenizer.decode(
-                latents,
+                latents_for_video,
                 height = image_height,
                 width = image_width
             )
 
+            video = unpack_view(video, '* t c vh vw')
+
+        # remove the lone view dimension
+
+        if not self.video_has_multi_view:
+            latents = rearrange(latents, 'b t 1 ... -> b t ...')
+
+            if exists(video):
+                video = rearrange(video, 'b 1 ... -> b ...')
+
         # only return video or latent if not requesting anything else, for first stage training
 
         if not has_at_least_one(return_rewards_per_frame, return_agent_actions, has_proprio):
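The video tokenizer's decode still expects single-view input, so the view axis is folded into the batch before decoding and restored afterwards. pack_one here is the common single-tensor convenience wrapper around einops pack/unpack; a minimal sketch of the pattern with plain einops (the decoder call itself is stubbed out):

import torch
from einops import pack, unpack, rearrange

b, t, v, n, d = 2, 3, 2, 4, 16
latents = torch.randn(b, t, v, n, d)

# fold views into the batch so a single-view decoder can run unchanged
x = rearrange(latents, 'b t v n d -> b v t n d')
x, ps = pack([x], '* t n d')               # (b * v, t, n, d)

video = x                                   # stand-in for video_tokenizer.decode(x, ...)

# restore the (b, v) leading axes on whatever the decoder returned
[video] = unpack(video, ps, '* t n d')      # back to (b, v, t, n, d)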
@@ -2553,8 +2588,9 @@ class DynamicsWorldModel(Module):
     def forward(
         self,
         *,
-        video = None,   # (b c t vh vw)
-        latents = None, # (b t n d) | (b t d)
+        video = None,   # (b v? c t vh vw)
+        latents = None, # (b t v? n d) | (b t v? d)
+        lens = None,    # (b)
         signal_levels = None, # () | (b) | (b t)
         step_sizes = None, # () | (b)
         step_sizes_log2 = None, # () | (b)
@@ -2572,21 +2608,39 @@ class DynamicsWorldModel(Module):
         return_all_losses = False,
         return_intermediates = False,
         add_autoregressive_action_loss = False,
-        update_loss_ema = None
+        update_loss_ema = None,
+        latent_has_view_dim = False
     ):
         # handle video or latents
 
         assert exists(video) ^ exists(latents)
 
+        # standardize view dimension
+
+        if not self.video_has_multi_view:
+            if exists(video):
+                video = rearrange(video, 'b ... -> b 1 ...')
+
+            if exists(latents) and not latent_has_view_dim:
+                latents = rearrange(latents, 'b t ... -> b t 1 ...')
+
+        # if raw video passed in, tokenize
+
         if exists(video):
+            assert video.ndim == 6
+
+            video, unpack_views = pack_one(video, '* c t vh vw')
             assert exists(self.video_tokenizer), 'video_tokenizer must be passed in if training from raw video on dynamics model'
 
             latents = self.video_tokenizer.tokenize(video)
+            latents = unpack_views(latents, '* t n d')
+            latents = rearrange(latents, 'b v t n d -> b t v n d')
 
-        if latents.ndim == 3:
-            latents = rearrange(latents, 'b t d -> b t 1 d') # 1 latent edge case
+        if latents.ndim == 4:
+            latents = rearrange(latents, 'b t v d -> b t v 1 d') # 1 latent edge case
 
         assert latents.shape[-2:] == self.latent_shape
+        assert latents.shape[2] == self.num_video_views
 
         # variables
 
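The net effect is that single-view callers keep the old shapes while multi-view callers add a v axis. A short sketch of both calling conventions against the new assertions (pure shape bookkeeping, no model required; sizes are illustrative):

import torch
from einops import rearrange

# single view: callers still pass (b c t vh vw); forward inserts v = 1 itself
video = torch.randn(2, 3, 10, 64, 64)
video = rearrange(video, 'b ... -> b 1 ...')   # (b v c t vh vw)
assert video.ndim == 6

# multi view: callers pass (b v c t vh vw) directly, with v == num_video_views
multi_view_video = torch.randn(2, 2, 3, 10, 64, 64)
assert multi_view_video.ndim == 6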
@@ -2769,6 +2823,7 @@ class DynamicsWorldModel(Module):
         # main function, needs to be defined as such for shortcut training - additional calls for consistency loss
 
         def get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, action_tokens, reward_tokens, agent_tokens, return_agent_tokens = False, return_time_kv_cache = False):
+
             # latents to spatial tokens
 
             space_tokens = self.latents_to_spatial_tokens(noised_latents)
@@ -2969,7 +3024,19 @@ class DynamicsWorldModel(Module):
 
         flow_losses = flow_losses * loss_weight
 
-        flow_loss = flow_losses.mean()
+        # handle variable lengths if needed
+
+        is_var_len = exists(lens)
+
+        if is_var_len:
+
+            loss_mask = lens_to_mask(lens, time)
+            loss_mask_without_last = loss_mask[:, :-1]
+
+            flow_loss = flow_losses[loss_mask].mean()
+
+        else:
+            flow_loss = flow_losses.mean()
 
         # now take care of the agent token losses
 
@@ -2992,7 +3059,10 @@ class DynamicsWorldModel(Module):
 
         reward_losses = reward_losses.masked_fill(reward_loss_mask, 0.)
 
-        reward_loss = reduce(reward_losses, '... mtp -> mtp', 'mean') # they sum across the prediction steps (mtp dimension) - eq(9)
+        if is_var_len:
+            reward_loss = reward_losses[loss_mask_without_last].mean(dim = 0)
+        else:
+            reward_loss = reduce(reward_losses, '... mtp -> mtp', 'mean') # they sum across the prediction steps (mtp dimension) - eq(9)
 
         # maybe autoregressive action loss
 
@@ -3035,12 +3105,20 @@ class DynamicsWorldModel(Module):
         if exists(discrete_log_probs):
             discrete_log_probs = discrete_log_probs.masked_fill(~discrete_mask[..., None], 0.)
 
-            discrete_action_loss = reduce(-discrete_log_probs, 'mtp b t na -> mtp', 'mean')
+            if is_var_len:
+                discrete_action_losses = rearrange(-discrete_log_probs, 'mtp b t na -> b t na mtp')
+                discrete_action_loss = reduce(discrete_action_losses[loss_mask_without_last], '... mtp -> mtp', 'mean')
+            else:
+                discrete_action_loss = reduce(-discrete_log_probs, 'mtp b t na -> mtp', 'mean')
 
         if exists(continuous_log_probs):
             continuous_log_probs = continuous_log_probs.masked_fill(~continuous_mask[..., None], 0.)
 
-            continuous_action_loss = reduce(-continuous_log_probs, 'mtp b t na -> mtp', 'mean')
+            if is_var_len:
+                continuous_action_losses = rearrange(-continuous_log_probs, 'mtp b t na -> b t na mtp')
+                continuous_action_loss = reduce(continuous_action_losses[loss_mask_without_last], '... mtp -> mtp', 'mean')
+            else:
+                continuous_action_loss = reduce(-continuous_log_probs, 'mtp b t na -> mtp', 'mean')
 
         # handle loss normalization
 
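The variable-length branches all follow the same pattern: build a (b, t) mask from lens, drop the final step (targets are one step ahead), then boolean-index before averaging so padded frames never contribute. A small numeric sketch of that reduction (illustrative shapes only):

import torch

b, t, mtp = 2, 4, 2
lens = torch.tensor([2, 4])

loss_mask = torch.arange(t)[None, :] < lens[:, None]   # (b, t), same result as lens_to_mask
loss_mask_without_last = loss_mask[:, :-1]             # (b, t - 1)

losses = torch.randn(b, t - 1, mtp)                    # per-step losses with an mtp axis

# boolean indexing keeps only the valid (batch, time) cells, giving (num_valid, mtp)
loss_per_mtp = losses[loss_mask_without_last].mean(dim = 0)   # (mtp,)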
tests/test_dreamer.py

@@ -16,6 +16,7 @@ def exists(v):
 @param('num_residual_streams', (1, 4))
 @param('add_reward_embed_to_agent_token', (False, True))
 @param('use_time_kv_cache', (False, True))
+@param('var_len', (False, True))
 def test_e2e(
     pred_orig_latent,
     grouped_query_attn,
@@ -27,7 +28,8 @@ def test_e2e(
     condition_on_actions,
     num_residual_streams,
     add_reward_embed_to_agent_token,
-    use_time_kv_cache
+    use_time_kv_cache,
+    var_len
 ):
     from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
 
@@ -95,8 +97,13 @@ def test_e2e(
     if condition_on_actions:
         actions = torch.randint(0, 4, (2, 3, 1))
 
+    lens = None
+    if var_len:
+        lens = torch.randint(1, 4, (2,))
+
     flow_loss = dynamics(
         **dynamics_input,
+        lens = lens,
         tasks = tasks,
         signal_levels = signal_levels,
         step_sizes_log2 = step_sizes_log2,
@@ -668,7 +675,10 @@ def test_online_rl(
 
     trainer(mock_env, num_episodes = 2, env_is_vectorized = vectorized)
 
-def test_proprioception():
+@param('num_video_views', (1, 2))
+def test_proprioception(
+    num_video_views
+):
     from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
 
     tokenizer = VideoTokenizer(
@@ -693,11 +703,17 @@ def test_proprioception():
         dim_latent = 32,
         dim_proprio = 21,
         num_tasks = 4,
+        num_video_views = num_video_views,
         num_discrete_actions = 4,
         num_residual_streams = 1
     )
 
-    video = torch.randn(2, 3, 10, 256, 256)
+    if num_video_views > 1:
+        video_shape = (2, num_video_views, 3, 10, 256, 256)
+    else:
+        video_shape = (2, 3, 10, 256, 256)
+
+    video = torch.randn(*video_shape)
     rewards = torch.randn(2, 10)
     proprio = torch.randn(2, 10, 21)
     discrete_actions = torch.randint(0, 4, (2, 10, 1))
@@ -714,8 +730,10 @@ def test_proprioception():
     loss.backward()
 
     generations = dynamics.generate(
-
+        10,
         batch_size = 2,
+        return_decoded_video = True
     )
 
     assert exists(generations.proprio)
+    assert generations.video.shape == video_shape