dreamer4-0.0.70-py3-none-any.whl → dreamer4-0.0.71-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dreamer4 might be problematic. Click here for more details.
- dreamer4/dreamer4.py +61 -16
- {dreamer4-0.0.70.dist-info → dreamer4-0.0.71.dist-info}/METADATA +1 -1
- dreamer4-0.0.71.dist-info/RECORD +8 -0
- dreamer4-0.0.70.dist-info/RECORD +0 -8
- {dreamer4-0.0.70.dist-info → dreamer4-0.0.71.dist-info}/WHEEL +0 -0
- {dreamer4-0.0.70.dist-info → dreamer4-0.0.71.dist-info}/licenses/LICENSE +0 -0
dreamer4/dreamer4.py
CHANGED
|
@@ -45,6 +45,7 @@ from assoc_scan import AssocScan
|
|
|
45
45
|
# vc - video channels
|
|
46
46
|
# vh, vw - video height and width
|
|
47
47
|
# mtp - multi token prediction length
|
|
48
|
+
# v - video viewpoints
|
|
48
49
|
|
|
49
50
|
import einx
|
|
50
51
|
from einx import add, multiply
|
|
@@ -1735,6 +1736,7 @@ class DynamicsWorldModel(Module):
|
|
|
1735
1736
|
num_latent_tokens = None,
|
|
1736
1737
|
num_agents = 1,
|
|
1737
1738
|
num_tasks = 0,
|
|
1739
|
+
num_video_views = 1,
|
|
1738
1740
|
dim_proprio = None,
|
|
1739
1741
|
reward_encoder_kwargs: dict = dict(),
|
|
1740
1742
|
depth = 4,
|
|
@@ -1800,7 +1802,7 @@ class DynamicsWorldModel(Module):
|
|
|
1800
1802
|
)
|
|
1801
1803
|
|
|
1802
1804
|
self.to_latent_pred = Sequential(
|
|
1803
|
-
Reduce('b t n s d -> b t n d', 'mean'),
|
|
1805
|
+
Reduce('b t v n s d -> b t v n d', 'mean'),
|
|
1804
1806
|
RMSNorm(dim),
|
|
1805
1807
|
LinearNoBias(dim, dim_latent)
|
|
1806
1808
|
)
|
|
@@ -1810,17 +1812,27 @@ class DynamicsWorldModel(Module):
|
|
|
1810
1812
|
latent_tokens_to_space = num_latent_tokens // num_spatial_tokens
|
|
1811
1813
|
|
|
1812
1814
|
self.latents_to_spatial_tokens = Sequential(
|
|
1813
|
-
Rearrange('
|
|
1815
|
+
Rearrange('... n d -> ... (n d)'),
|
|
1814
1816
|
Linear(num_latent_tokens * dim_latent, dim * num_spatial_tokens),
|
|
1815
|
-
Rearrange('
|
|
1817
|
+
Rearrange('... (s d) -> ... s d', s = num_spatial_tokens)
|
|
1816
1818
|
)
|
|
1817
1819
|
|
|
1818
1820
|
self.to_latent_pred = Sequential(
|
|
1819
1821
|
RMSNorm(dim),
|
|
1820
1822
|
LinearNoBias(dim, dim_latent * latent_tokens_to_space),
|
|
1821
|
-
Rearrange('b t s (n d) -> b t (s n) d', n = latent_tokens_to_space)
|
|
1823
|
+
Rearrange('b t v s (n d) -> b t v (s n) d', n = latent_tokens_to_space)
|
|
1822
1824
|
)
|
|
1823
1825
|
|
|
1826
|
+
# number of video views, for robotics, which could have third person + wrist camera at least
|
|
1827
|
+
|
|
1828
|
+
assert num_video_views >= 1
|
|
1829
|
+
self.video_has_multi_view = num_video_views > 1
|
|
1830
|
+
|
|
1831
|
+
self.num_video_views = num_video_views
|
|
1832
|
+
|
|
1833
|
+
if self.video_has_multi_view:
|
|
1834
|
+
self.view_emb = nn.Parameter(torch.randn(num_video_views, dim) * 1e-2)
|
|
1835
|
+
|
|
1824
1836
|
# proprioception
|
|
1825
1837
|
|
|
1826
1838
|
self.has_proprio = exists(dim_proprio)
|
|
@@ -2318,7 +2330,7 @@ class DynamicsWorldModel(Module):
|
|
|
2318
2330
|
# denoising
|
|
2319
2331
|
# teacher forcing to start with
|
|
2320
2332
|
|
|
2321
|
-
latents = empty((batch_size, 0, *latent_shape), device = self.device)
|
|
2333
|
+
latents = empty((batch_size, 0, self.num_video_views, *latent_shape), device = self.device)
|
|
2322
2334
|
|
|
2323
2335
|
past_latents_context_noise = latents.clone()
|
|
2324
2336
|
|
|
@@ -2354,7 +2366,7 @@ class DynamicsWorldModel(Module):
|
|
|
2354
2366
|
|
|
2355
2367
|
curr_time_steps = latents.shape[1]
|
|
2356
2368
|
|
|
2357
|
-
noised_latent = randn((batch_size, 1, *latent_shape), device = self.device)
|
|
2369
|
+
noised_latent = randn((batch_size, 1, self.num_video_views, *latent_shape), device = self.device)
|
|
2358
2370
|
|
|
2359
2371
|
noised_proprio = None
|
|
2360
2372
|
|
|
@@ -2365,12 +2377,12 @@ class DynamicsWorldModel(Module):
|
|
|
2365
2377
|
is_last_step = (step + 1) == num_steps
|
|
2366
2378
|
|
|
2367
2379
|
signal_levels = full((batch_size, 1), step * step_size, dtype = torch.long, device = self.device)
|
|
2368
|
-
|
|
2380
|
+
|
|
2369
2381
|
# noising past latent context
|
|
2370
2382
|
|
|
2371
2383
|
noised_context = latents.lerp(past_latents_context_noise, context_signal_noise) # the paragraph after eq (8)
|
|
2372
2384
|
|
|
2373
|
-
noised_latent_with_context, pack_context_shape = pack((noised_context, noised_latent), 'b * n d')
|
|
2385
|
+
noised_latent_with_context, pack_context_shape = pack((noised_context, noised_latent), 'b * v n d')
|
|
2374
2386
|
|
|
2375
2387
|
# handle proprio
|
|
2376
2388
|
|
|
@@ -2395,6 +2407,7 @@ class DynamicsWorldModel(Module):
|
|
|
2395
2407
|
proprio = noised_proprio_with_context,
|
|
2396
2408
|
time_kv_cache = time_kv_cache,
|
|
2397
2409
|
latent_is_noised = True,
|
|
2410
|
+
latent_has_view_dim = True,
|
|
2398
2411
|
return_pred_only = True,
|
|
2399
2412
|
return_intermediates = True,
|
|
2400
2413
|
)
|
|
@@ -2409,12 +2422,11 @@ class DynamicsWorldModel(Module):
|
|
|
2409
2422
|
|
|
2410
2423
|
# unpack pred
|
|
2411
2424
|
|
|
2412
|
-
_, pred = unpack(pred, pack_context_shape, 'b * n d')
|
|
2425
|
+
_, pred = unpack(pred, pack_context_shape, 'b * v n d')
|
|
2413
2426
|
|
|
2414
2427
|
if has_proprio:
|
|
2415
2428
|
_, pred_proprio = unpack(pred_proprio, pack_context_shape, 'b * d')
|
|
2416
2429
|
|
|
2417
|
-
|
|
2418
2430
|
# derive flow, based on whether in x-space or not
|
|
2419
2431
|
|
|
2420
2432
|
def denoise_step(pred, noised, signal_levels):
|
|
@@ -2507,12 +2519,26 @@ class DynamicsWorldModel(Module):
|
|
|
2507
2519
|
video = None
|
|
2508
2520
|
|
|
2509
2521
|
if return_decoded_video:
|
|
2522
|
+
|
|
2523
|
+
latents_for_video = rearrange(latents, 'b t v n d -> b v t n d')
|
|
2524
|
+
latents_for_video, unpack_view = pack_one(latents_for_video, '* t n d')
|
|
2525
|
+
|
|
2510
2526
|
video = self.video_tokenizer.decode(
|
|
2511
|
-
|
|
2527
|
+
latents_for_video,
|
|
2512
2528
|
height = image_height,
|
|
2513
2529
|
width = image_width
|
|
2514
2530
|
)
|
|
2515
2531
|
|
|
2532
|
+
video = unpack_view(video, '* t c vh vw')
|
|
2533
|
+
|
|
2534
|
+
# remove the lone view dimension
|
|
2535
|
+
|
|
2536
|
+
if not self.video_has_multi_view:
|
|
2537
|
+
latents = rearrange(latents, 'b t 1 ... -> b t ...')
|
|
2538
|
+
|
|
2539
|
+
if exists(video):
|
|
2540
|
+
video = rearrange(video, 'b 1 ... -> b ...')
|
|
2541
|
+
|
|
2516
2542
|
# only return video or latent if not requesting anything else, for first stage training
|
|
2517
2543
|
|
|
2518
2544
|
if not has_at_least_one(return_rewards_per_frame, return_agent_actions, has_proprio):
|
|
@@ -2553,8 +2579,8 @@ class DynamicsWorldModel(Module):
|
|
|
2553
2579
|
def forward(
|
|
2554
2580
|
self,
|
|
2555
2581
|
*,
|
|
2556
|
-
video = None, # (b c t vh vw)
|
|
2557
|
-
latents = None, # (b t n d) | (b t d)
|
|
2582
|
+
video = None, # (b v? c t vh vw)
|
|
2583
|
+
latents = None, # (b t v? n d) | (b t v? d)
|
|
2558
2584
|
signal_levels = None, # () | (b) | (b t)
|
|
2559
2585
|
step_sizes = None, # () | (b)
|
|
2560
2586
|
step_sizes_log2 = None, # () | (b)
|
|
@@ -2572,21 +2598,39 @@ class DynamicsWorldModel(Module):
|
|
|
2572
2598
|
return_all_losses = False,
|
|
2573
2599
|
return_intermediates = False,
|
|
2574
2600
|
add_autoregressive_action_loss = False,
|
|
2575
|
-
update_loss_ema = None
|
|
2601
|
+
update_loss_ema = None,
|
|
2602
|
+
latent_has_view_dim = False
|
|
2576
2603
|
):
|
|
2577
2604
|
# handle video or latents
|
|
2578
2605
|
|
|
2579
2606
|
assert exists(video) ^ exists(latents)
|
|
2580
2607
|
|
|
2608
|
+
# standardize view dimension
|
|
2609
|
+
|
|
2610
|
+
if not self.video_has_multi_view:
|
|
2611
|
+
if exists(video):
|
|
2612
|
+
video = rearrange(video, 'b ... -> b 1 ...')
|
|
2613
|
+
|
|
2614
|
+
if exists(latents) and not latent_has_view_dim:
|
|
2615
|
+
latents = rearrange(latents, 'b t ... -> b t 1 ...')
|
|
2616
|
+
|
|
2617
|
+
# if raw video passed in, tokenize
|
|
2618
|
+
|
|
2581
2619
|
if exists(video):
|
|
2620
|
+
assert video.ndim == 6
|
|
2621
|
+
|
|
2622
|
+
video, unpack_views = pack_one(video, '* c t vh vw')
|
|
2582
2623
|
assert exists(self.video_tokenizer), 'video_tokenizer must be passed in if training from raw video on dynamics model'
|
|
2583
2624
|
|
|
2584
2625
|
latents = self.video_tokenizer.tokenize(video)
|
|
2626
|
+
latents = unpack_views(latents, '* t n d')
|
|
2627
|
+
latents = rearrange(latents, 'b v t n d -> b t v n d')
|
|
2585
2628
|
|
|
2586
|
-
if latents.ndim ==
|
|
2587
|
-
latents = rearrange(latents, 'b t d -> b t 1 d') # 1 latent edge case
|
|
2629
|
+
if latents.ndim == 4:
|
|
2630
|
+
latents = rearrange(latents, 'b t v d -> b t v 1 d') # 1 latent edge case
|
|
2588
2631
|
|
|
2589
2632
|
assert latents.shape[-2:] == self.latent_shape
|
|
2633
|
+
assert latents.shape[2] == self.num_video_views
|
|
2590
2634
|
|
|
2591
2635
|
# variables
|
|
2592
2636
|
|
|
@@ -2769,6 +2813,7 @@ class DynamicsWorldModel(Module):
|
|
|
2769
2813
|
# main function, needs to be defined as such for shortcut training - additional calls for consistency loss
|
|
2770
2814
|
|
|
2771
2815
|
def get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, action_tokens, reward_tokens, agent_tokens, return_agent_tokens = False, return_time_kv_cache = False):
|
|
2816
|
+
|
|
2772
2817
|
# latents to spatial tokens
|
|
2773
2818
|
|
|
2774
2819
|
space_tokens = self.latents_to_spatial_tokens(noised_latents)
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
|
|
2
|
+
dreamer4/dreamer4.py,sha256=ywbJM384Z1VXQOHyv5RwUaCmsKAJyo2CmaDpVknml2c,103385
|
|
3
|
+
dreamer4/mocks.py,sha256=Oi91Yv1oK0E-Wz-KDkf79xoyWzIXCvMLCr0WYCpJDLA,1482
|
|
4
|
+
dreamer4/trainers.py,sha256=898ye9Y1mqxGZnU_gfQS6pECibZwwyA43sL7wK_JHAU,13993
|
|
5
|
+
dreamer4-0.0.71.dist-info/METADATA,sha256=wUWSPxR5xZfSsICp3GLE-p25hOSdqcdV-7Zh_vAYV64,3065
|
|
6
|
+
dreamer4-0.0.71.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
7
|
+
dreamer4-0.0.71.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
8
|
+
dreamer4-0.0.71.dist-info/RECORD,,
|
dreamer4-0.0.70.dist-info/RECORD
DELETED
|
@@ -1,8 +0,0 @@
|
|
|
1
|
-
dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
|
|
2
|
-
dreamer4/dreamer4.py,sha256=I1wuhtl9Oqi52x8PAsu_8DXFsJZqwoeFTIPhZGpBLNY,101718
|
|
3
|
-
dreamer4/mocks.py,sha256=Oi91Yv1oK0E-Wz-KDkf79xoyWzIXCvMLCr0WYCpJDLA,1482
|
|
4
|
-
dreamer4/trainers.py,sha256=898ye9Y1mqxGZnU_gfQS6pECibZwwyA43sL7wK_JHAU,13993
|
|
5
|
-
dreamer4-0.0.70.dist-info/METADATA,sha256=N4xH-8IdKAxNmPOa2K97TMGvY4As2-gxbSKzOzf9IK4,3065
|
|
6
|
-
dreamer4-0.0.70.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
7
|
-
dreamer4-0.0.70.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
8
|
-
dreamer4-0.0.70.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|