dreamer4 0.0.82__tar.gz → 0.0.84__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of dreamer4 might be problematic.
- {dreamer4-0.0.82 → dreamer4-0.0.84}/PKG-INFO +1 -1
- {dreamer4-0.0.82 → dreamer4-0.0.84}/dreamer4/dreamer4.py +45 -10
- {dreamer4-0.0.82 → dreamer4-0.0.84}/dreamer4/mocks.py +18 -6
- {dreamer4-0.0.82 → dreamer4-0.0.84}/pyproject.toml +1 -1
- {dreamer4-0.0.82 → dreamer4-0.0.84}/tests/test_dreamer.py +11 -3
- {dreamer4-0.0.82 → dreamer4-0.0.84}/.github/workflows/python-publish.yml +0 -0
- {dreamer4-0.0.82 → dreamer4-0.0.84}/.github/workflows/test.yml +0 -0
- {dreamer4-0.0.82 → dreamer4-0.0.84}/.gitignore +0 -0
- {dreamer4-0.0.82 → dreamer4-0.0.84}/LICENSE +0 -0
- {dreamer4-0.0.82 → dreamer4-0.0.84}/README.md +0 -0
- {dreamer4-0.0.82 → dreamer4-0.0.84}/dreamer4/__init__.py +0 -0
- {dreamer4-0.0.82 → dreamer4-0.0.84}/dreamer4/trainers.py +0 -0
- {dreamer4-0.0.82 → dreamer4-0.0.84}/dreamer4-fig2.png +0 -0
--- dreamer4-0.0.82/dreamer4/dreamer4.py
+++ dreamer4-0.0.84/dreamer4/dreamer4.py
@@ -77,6 +77,7 @@ class Experience:
     latents: Tensor
     video: Tensor | None = None
     proprio: Tensor | None = None
+    agent_embed: Tensor | None = None,
     rewards: Tensor | None = None
     actions: tuple[Tensor, Tensor] | None = None
     log_probs: tuple[Tensor, Tensor] | None = None
@@ -2105,7 +2106,8 @@ class DynamicsWorldModel(Module):
         step_size = 4,
         max_timesteps = 16,
         env_is_vectorized = False,
-        use_time_kv_cache = True
+        use_time_kv_cache = True,
+        store_agent_embed = True
     ):
         assert exists(self.video_tokenizer)
 
@@ -2130,16 +2132,23 @@ class DynamicsWorldModel(Module):
         values = None
         latents = None
 
+        acc_agent_embed = None
+
         # keep track of termination, for setting the `is_truncated` field on Experience and for early stopping interaction with env
 
         is_terminated = full((batch,), False, device = device)
+        is_truncated = full((batch,), False, device = device)
+
         episode_lens = full((batch,), 0, device = device)
 
         # maybe time kv cache
 
         time_kv_cache = None
 
-
+        step_index = 0
+
+        while not is_terminated.all():
+            step_index += 1
 
             latents = self.video_tokenizer(video, return_latents = True)
 
@@ -2201,10 +2210,15 @@ class DynamicsWorldModel(Module):
 
             if len(env_step_out) == 2:
                 next_frame, reward = env_step_out
-
+                terminated = full((batch,), False)
+                truncated = full((batch,), False)
 
             elif len(env_step_out) == 3:
-                next_frame, reward,
+                next_frame, reward, terminated = env_step_out
+                truncated = full((batch,), False)
+
+            elif len(env_step_out) == 4:
+                next_frame, reward, terminated, truncated = env_step_out
 
             # update episode lens
 
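Note on the hunk above: the rollout loop (in what the test changes further down identify as `interact_with_env`) now tolerates environments whose step output has two, three, or four elements, synthesizing the missing `terminated` / `truncated` flags. A minimal standalone sketch of that dispatch; the helper name `unpack_env_step` is hypothetical and not part of the package:

    import torch

    def unpack_env_step(env_step_out, batch):
        # mirror of the 2/3/4-tuple handling added in 0.0.84 (helper name is hypothetical)
        if len(env_step_out) == 2:
            next_frame, reward = env_step_out
            terminated = torch.full((batch,), False)
            truncated = torch.full((batch,), False)
        elif len(env_step_out) == 3:
            next_frame, reward, terminated = env_step_out
            truncated = torch.full((batch,), False)
        elif len(env_step_out) == 4:
            next_frame, reward, terminated, truncated = env_step_out
        else:
            raise ValueError('env step output must have 2, 3 or 4 elements')
        return next_frame, reward, terminated, truncated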
@@ -2212,7 +2226,20 @@ class DynamicsWorldModel(Module):
 
             # update `is_terminated`
 
-
+            # (1) - environment says it is terminated
+            # (2) - previous step is truncated (this step is for bootstrap value)
+
+            is_terminated |= (terminated | is_truncated)
+
+            # update `is_truncated`
+
+            if step_index <= max_timesteps:
+                is_truncated |= truncated
+
+            if step_index == max_timesteps:
+                # if the step index is at the max time step allowed, set the truncated flag, if not already terminated
+
+                is_truncated |= ~is_terminated
 
             # batch and time dimension
 
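The hunk above is the new truncation bookkeeping: an episode becomes terminated when the environment says so or when the previous step was truncated (that extra step exists only to bootstrap the value estimate), and becomes truncated either when the environment truncates it within the step budget or when `max_timesteps` is reached without termination. A self-contained sketch of the same bookkeeping, using random stand-ins for the environment flags:

    import torch

    batch, max_timesteps = 4, 16                      # illustrative sizes
    is_terminated = torch.full((batch,), False)
    is_truncated = torch.full((batch,), False)

    step_index = 0
    while not is_terminated.all():
        step_index += 1
        terminated = torch.rand(batch) < 0.1          # stand-in for the env's terminated flag
        truncated = torch.rand(batch) < 0.05          # stand-in for the env's truncated flag

        # (1) env terminated, or (2) previous step truncated (this step bootstraps the value)
        is_terminated |= (terminated | is_truncated)

        if step_index <= max_timesteps:
            is_truncated |= truncated

        if step_index == max_timesteps:
            # at the step budget, truncate anything not already terminated
            is_truncated |= ~is_terminated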
@@ -2228,10 +2255,7 @@ class DynamicsWorldModel(Module):
             video = cat((video, next_frame), dim = 2)
             rewards = safe_cat((rewards, reward), dim = 1)
 
-
-
-            if is_terminated.all():
-                break
+            acc_agent_embed = safe_cat((acc_agent_embed, agent_embed), dim = 1)
 
         # package up one experience for learning
 
@@ -2244,9 +2268,10 @@ class DynamicsWorldModel(Module):
             actions = (discrete_actions, continuous_actions),
             log_probs = (discrete_log_probs, continuous_log_probs),
             values = values,
+            agent_embed = acc_agent_embed if store_agent_embed else None,
             step_size = step_size,
             agent_index = agent_index,
-            is_truncated =
+            is_truncated = is_truncated,
             lens = episode_lens,
             is_from_world_model = False
         )
@@ -2473,6 +2498,7 @@ class DynamicsWorldModel(Module):
         return_agent_actions = False,
         return_log_probs_and_values = False,
         return_time_kv_cache = False,
+        store_agent_embed = False
 
     ): # (b t n d) | (b c t h w)
 
@@ -2525,6 +2551,10 @@ class DynamicsWorldModel(Module):
         decoded_continuous_log_probs = None
         decoded_values = None
 
+        # maybe store agent embed
+
+        acc_agent_embed = None
+
         # maybe return rewards
 
         decoded_rewards = None
@@ -2633,6 +2663,10 @@ class DynamicsWorldModel(Module):
 
                 decoded_rewards = cat((decoded_rewards, pred_reward), dim = 1)
 
+            # maybe store agent embed
+
+            acc_agent_embed = safe_cat((acc_agent_embed, agent_embed), dim = 1)
+
             # decode the agent actions if needed
 
             if return_agent_actions:
@@ -2729,6 +2763,7 @@ class DynamicsWorldModel(Module):
             latents = latents,
             video = video,
             proprio = proprio if has_proprio else None,
+            agent_embed = acc_agent_embed if store_agent_embed else None,
             step_size = step_size,
             agent_index = agent_index,
             lens = experience_lens,
--- dreamer4-0.0.82/dreamer4/mocks.py
+++ dreamer4-0.0.84/dreamer4/mocks.py
@@ -22,7 +22,9 @@ class MockEnv(Module):
         num_envs = 1,
         vectorized = False,
         terminate_after_step = None,
-        rand_terminate_prob = 0.05
+        rand_terminate_prob = 0.05,
+        can_truncate = False,
+        rand_truncate_prob = 0.05,
     ):
         super().__init__()
         self.image_shape = image_shape
@@ -32,12 +34,15 @@ class MockEnv(Module):
         self.vectorized = vectorized
         assert not (vectorized and num_envs == 1)
 
-        # mocking termination
+        # mocking termination and truncation
 
         self.can_terminate = exists(terminate_after_step)
         self.terminate_after_step = terminate_after_step
         self.rand_terminate_prob = rand_terminate_prob
 
+        self.can_truncate = can_truncate
+        self.rand_truncate_prob = rand_truncate_prob
+
         self.register_buffer('_step', tensor(0))
 
     def get_random_state(self):
@@ -72,14 +77,21 @@ class MockEnv(Module):
 
         out = (state, reward)
 
+
         if self.can_terminate:
-
-
-
-            )
+            shape = (self.num_envs,) if self.vectorized else (1,)
+            valid_step = self._step > self.terminate_after_step
+
+            terminate = (torch.rand(shape) < self.rand_terminate_prob) & valid_step
 
             out = (*out, terminate)
 
+        # maybe truncation
+
+        if self.can_truncate:
+            truncate = (torch.rand(shape) < self.rand_truncate_prob) & valid_step & ~terminate
+            out = (*out, truncate)
+
         self._step.add_(1)
 
         return out
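With this change `MockEnv` can also emit a truncation flag after the termination flag, sampled only on valid steps and masked out wherever termination already fired. A hedged construction sketch; every keyword value below is illustrative, and the expected format of `image_shape` in particular is a guess:

    from dreamer4.mocks import MockEnv  # module path taken from the file list above

    mock_env = MockEnv(
        image_shape = (3, 64, 64),       # format assumed, not taken from the diff
        num_envs = 4,
        vectorized = True,
        terminate_after_step = 2,
        rand_terminate_prob = 0.1,
        can_truncate = True,             # new in 0.0.84
        rand_truncate_prob = 0.05,       # new in 0.0.84
    )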
--- dreamer4-0.0.82/tests/test_dreamer.py
+++ dreamer4-0.0.84/tests/test_dreamer.py
@@ -613,10 +613,14 @@ def test_cache_generate():
 @param('vectorized', (False, True))
 @param('use_signed_advantage', (False, True))
 @param('env_can_terminate', (False, True))
+@param('env_can_truncate', (False, True))
+@param('store_agent_embed', (False, True))
 def test_online_rl(
     vectorized,
     use_signed_advantage,
-    env_can_terminate
+    env_can_terminate,
+    env_can_truncate,
+    store_agent_embed
 ):
     from dreamer4.dreamer4 import DynamicsWorldModel, VideoTokenizer
 
@@ -656,16 +660,20 @@ def test_online_rl(
         vectorized = vectorized,
         num_envs = 4,
         terminate_after_step = 2 if env_can_terminate else None,
+        can_truncate = env_can_truncate,
         rand_terminate_prob = 0.1
     )
 
     # manually
 
-    one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = vectorized)
-    another_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized)
+    one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = vectorized, store_agent_embed = store_agent_embed)
+    another_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized, store_agent_embed = store_agent_embed)
 
     combined_experience = combine_experiences([one_experience, another_experience])
 
+    if store_agent_embed:
+        assert exists(combined_experience.agent_embed)
+
     actor_loss, critic_loss = world_model_and_policy.learn_from_experience(combined_experience, use_signed_advantage = use_signed_advantage)
 
     actor_loss.backward()
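As the updated test shows, passing `store_agent_embed = True` to `interact_with_env` makes the returned `Experience` carry the accumulated embeddings in its new `agent_embed` field. A hedged calling sketch, assuming `world_model_and_policy` (a `DynamicsWorldModel`) and `mock_env` are already constructed as in the test:

    # assumes pre-built instances; mirrors the call added to test_online_rl
    experience = world_model_and_policy.interact_with_env(
        mock_env,
        max_timesteps = 8,
        env_is_vectorized = True,
        store_agent_embed = True,   # new in 0.0.84
    )

    assert experience.agent_embed is not None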