dreamer4 0.0.83__tar.gz → 0.0.85__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Note: this release of dreamer4 has been flagged as potentially problematic; further details are available on the registry page.
- {dreamer4-0.0.83 → dreamer4-0.0.85}/PKG-INFO +1 -1
- {dreamer4-0.0.83 → dreamer4-0.0.85}/dreamer4/dreamer4.py +18 -1
- {dreamer4-0.0.83 → dreamer4-0.0.85}/pyproject.toml +1 -1
- {dreamer4-0.0.83 → dreamer4-0.0.85}/tests/test_dreamer.py +8 -3
- {dreamer4-0.0.83 → dreamer4-0.0.85}/.github/workflows/python-publish.yml +0 -0
- {dreamer4-0.0.83 → dreamer4-0.0.85}/.github/workflows/test.yml +0 -0
- {dreamer4-0.0.83 → dreamer4-0.0.85}/.gitignore +0 -0
- {dreamer4-0.0.83 → dreamer4-0.0.85}/LICENSE +0 -0
- {dreamer4-0.0.83 → dreamer4-0.0.85}/README.md +0 -0
- {dreamer4-0.0.83 → dreamer4-0.0.85}/dreamer4/__init__.py +0 -0
- {dreamer4-0.0.83 → dreamer4-0.0.85}/dreamer4/mocks.py +0 -0
- {dreamer4-0.0.83 → dreamer4-0.0.85}/dreamer4/trainers.py +0 -0
- {dreamer4-0.0.83 → dreamer4-0.0.85}/dreamer4-fig2.png +0 -0
--- dreamer4-0.0.83/dreamer4/dreamer4.py
+++ dreamer4-0.0.85/dreamer4/dreamer4.py

@@ -77,6 +77,7 @@ class Experience:
     latents: Tensor
     video: Tensor | None = None
     proprio: Tensor | None = None
+    agent_embed: Tensor | None = None,
     rewards: Tensor | None = None
     actions: tuple[Tensor, Tensor] | None = None
     log_probs: tuple[Tensor, Tensor] | None = None
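The substantive change in this release is an opt-in store_agent_embed flag: when enabled, the agent embeddings produced during a rollout are accumulated and carried on the returned Experience through the new agent_embed field added above (note the stray trailing comma on the released line, which makes the field's default the tuple (None,) rather than None). As a rough sketch, assuming the usual @dataclass decorator and showing only the fields visible in this hunk:

from dataclasses import dataclass

from torch import Tensor

# hypothetical reconstruction from the hunk above; the real Experience
# dataclass in dreamer4/dreamer4.py carries additional fields (values,
# step_size, agent_index, is_truncated, lens, ...) seen later in this diff
@dataclass
class Experience:
    latents: Tensor
    video: Tensor | None = None
    proprio: Tensor | None = None
    agent_embed: Tensor | None = None   # new in 0.0.85
    rewards: Tensor | None = None
    actions: tuple[Tensor, Tensor] | None = None
    log_probs: tuple[Tensor, Tensor] | None = None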
@@ -2105,7 +2106,8 @@ class DynamicsWorldModel(Module):
         step_size = 4,
         max_timesteps = 16,
         env_is_vectorized = False,
-        use_time_kv_cache = True
+        use_time_kv_cache = True,
+        store_agent_embed = False
     ):
         assert exists(self.video_tokenizer)

@@ -2130,6 +2132,8 @@ class DynamicsWorldModel(Module):
         values = None
         latents = None

+        acc_agent_embed = None
+
         # keep track of termination, for setting the `is_truncated` field on Experience and for early stopping interaction with env

         is_terminated = full((batch,), False, device = device)
@@ -2251,6 +2255,8 @@ class DynamicsWorldModel(Module):
             video = cat((video, next_frame), dim = 2)
             rewards = safe_cat((rewards, reward), dim = 1)

+            acc_agent_embed = safe_cat((acc_agent_embed, agent_embed), dim = 1)
+
         # package up one experience for learning

         batch, device = latents.shape[0], latents.device
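safe_cat itself is not part of this diff. A minimal sketch of a None-tolerant concatenation helper, consistent with how it is used above (the accumulator starts as None and grows one step at a time), might look as follows; the actual helper in dreamer4/dreamer4.py may differ:

import torch

# hypothetical sketch: concatenate while ignoring None operands, returning
# None if nothing is left; this mirrors how safe_cat is used in the hunks above
def safe_cat(tensors, dim = 0):
    tensors = [t for t in tensors if t is not None]

    if len(tensors) == 0:
        return None

    if len(tensors) == 1:
        return tensors[0]

    return torch.cat(tensors, dim = dim)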
@@ -2262,6 +2268,7 @@ class DynamicsWorldModel(Module):
             actions = (discrete_actions, continuous_actions),
             log_probs = (discrete_log_probs, continuous_log_probs),
             values = values,
+            agent_embed = acc_agent_embed if store_agent_embed else None,
             step_size = step_size,
             agent_index = agent_index,
             is_truncated = is_truncated,
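Taken together, callers opt into keeping the per-step agent embeddings when rolling out against an environment. A usage sketch, mirroring the updated test further down and assuming a constructed DynamicsWorldModel (world_model_and_policy) and an environment (mock_env):

# usage sketch; world_model_and_policy and mock_env are assumed to be
# constructed as in tests/test_dreamer.py
experience = world_model_and_policy.interact_with_env(
    mock_env,
    max_timesteps = 8,
    store_agent_embed = True
)

# agent_embed is populated only when store_agent_embed = True
assert experience.agent_embed is not None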
@@ -2491,6 +2498,7 @@ class DynamicsWorldModel(Module):
         return_agent_actions = False,
         return_log_probs_and_values = False,
         return_time_kv_cache = False,
+        store_agent_embed = False

     ): # (b t n d) | (b c t h w)

@@ -2543,6 +2551,10 @@ class DynamicsWorldModel(Module):
         decoded_continuous_log_probs = None
         decoded_values = None

+        # maybe store agent embed
+
+        acc_agent_embed = None
+
         # maybe return rewards

         decoded_rewards = None
@@ -2651,6 +2663,10 @@ class DynamicsWorldModel(Module):

             decoded_rewards = cat((decoded_rewards, pred_reward), dim = 1)

+            # maybe store agent embed
+
+            acc_agent_embed = safe_cat((acc_agent_embed, agent_embed), dim = 1)
+
             # decode the agent actions if needed

             if return_agent_actions:
@@ -2747,6 +2763,7 @@ class DynamicsWorldModel(Module):
             latents = latents,
             video = video,
             proprio = proprio if has_proprio else None,
+            agent_embed = acc_agent_embed if store_agent_embed else None,
             step_size = step_size,
             agent_index = agent_index,
             lens = experience_lens,
--- dreamer4-0.0.83/tests/test_dreamer.py
+++ dreamer4-0.0.85/tests/test_dreamer.py

@@ -614,11 +614,13 @@ def test_cache_generate():
 @param('use_signed_advantage', (False, True))
 @param('env_can_terminate', (False, True))
 @param('env_can_truncate', (False, True))
+@param('store_agent_embed', (False, True))
 def test_online_rl(
     vectorized,
     use_signed_advantage,
     env_can_terminate,
-    env_can_truncate
+    env_can_truncate,
+    store_agent_embed
 ):
     from dreamer4.dreamer4 import DynamicsWorldModel, VideoTokenizer

@@ -664,11 +666,14 @@ def test_online_rl(

     # manually

-    one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = vectorized)
-    another_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized)
+    one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = vectorized, store_agent_embed = store_agent_embed)
+    another_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized, store_agent_embed = store_agent_embed)

     combined_experience = combine_experiences([one_experience, another_experience])

+    if store_agent_embed:
+        assert exists(combined_experience.agent_embed)
+
     actor_loss, critic_loss = world_model_and_policy.learn_from_experience(combined_experience, use_signed_advantage = use_signed_advantage)

     actor_loss.backward()
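combine_experiences and exists are not touched by any hunk in this diff, so the new assertion relies on combine_experiences already propagating optional Experience fields. To run just this test locally, something like the following should work with a standard pytest setup:

pytest tests/test_dreamer.py -k test_online_rl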