dreamer4 0.0.80__py3-none-any.whl → 0.0.82__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dreamer4 might be problematic.
- dreamer4/dreamer4.py +75 -4
- dreamer4/mocks.py +31 -8
- {dreamer4-0.0.80.dist-info → dreamer4-0.0.82.dist-info}/METADATA +1 -1
- dreamer4-0.0.82.dist-info/RECORD +8 -0
- dreamer4-0.0.80.dist-info/RECORD +0 -8
- {dreamer4-0.0.80.dist-info → dreamer4-0.0.82.dist-info}/WHEEL +0 -0
- {dreamer4-0.0.80.dist-info → dreamer4-0.0.82.dist-info}/licenses/LICENSE +0 -0
dreamer4/dreamer4.py
CHANGED

```diff
@@ -1817,6 +1817,9 @@ class DynamicsWorldModel(Module):
         continuous_action_loss_weight: float | list[float] = 1.,
         num_latent_genes = 0, # for carrying out evolution within the dreams https://web3.arxiv.org/abs/2503.19037
         num_residual_streams = 1,
+        keep_reward_ema_stats = False,
+        reward_ema_decay = 0.998,
+        reward_quantile_filter = (0.05, 0.95),
         gae_discount_factor = 0.997,
         gae_lambda = 0.95,
         ppo_eps_clip = 0.2,
```
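The new `reward_quantile_filter` default of `(0.05, 0.95)` is a pair of quantiles used to clip outlier returns before statistics are taken (see the later hunk). As a minimal, self-contained illustration of what such a filter does (the data below is synthetic, not from the package):

```python
import torch

returns = torch.randn(64) * 10.                 # synthetic batch of returns
quantile_filter = torch.tensor((0.05, 0.95))    # mirrors the new default above

# clamp everything outside the 5th-95th percentile range
lo, hi = torch.quantile(returns, quantile_filter).tolist()
filtered_returns = returns.clamp(lo, hi)
```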
```diff
@@ -2022,6 +2025,16 @@ class DynamicsWorldModel(Module):
         self.value_clip = value_clip
         self.policy_entropy_weight = value_clip
 
+        # rewards related
+
+        self.keep_reward_ema_stats = keep_reward_ema_stats
+        self.reward_ema_decay = reward_ema_decay
+
+        self.register_buffer('reward_quantile_filter', tensor(reward_quantile_filter), persistent = False)
+
+        self.register_buffer('ema_returns_mean', tensor(0.))
+        self.register_buffer('ema_returns_var', tensor(1.))
+
         # loss related
 
         self.flow_loss_normalizer = LossNormalizer(1)
@@ -2105,6 +2118,8 @@ class DynamicsWorldModel(Module):
         else:
             video = rearrange(init_frame, 'c vh vw -> 1 c 1 vh vw')
 
+        batch, device = video.shape[0], video.device
+
         # accumulate
 
         rewards = None
@@ -2115,6 +2130,11 @@ class DynamicsWorldModel(Module):
         values = None
         latents = None
 
+        # keep track of termination, for setting the `is_truncated` field on Experience and for early stopping interaction with env
+
+        is_terminated = full((batch,), False, device = device)
+        episode_lens = full((batch,), 0, device = device)
+
         # maybe time kv cache
 
         time_kv_cache = None
@@ -2177,7 +2197,22 @@ class DynamicsWorldModel(Module):
 
             # pass the sampled action to the environment and get back next state and reward
 
-
+            env_step_out = env.step((sampled_discrete_actions, sampled_continuous_actions))
+
+            if len(env_step_out) == 2:
+                next_frame, reward = env_step_out
+                terminate = full((batch,), False)
+
+            elif len(env_step_out) == 3:
+                next_frame, reward, terminate = env_step_out
+
+            # update episode lens
+
+            episode_lens = torch.where(is_terminated, episode_lens, episode_lens + 1)
+
+            # update `is_terminated`
+
+            is_terminated |= terminate
 
             # batch and time dimension
 
```
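The hunks above thread per-environment termination through the rollout: `env.step` may return either a `(next_frame, reward)` pair or a `(next_frame, reward, terminate)` triple, episode lengths stop accumulating once an environment terminates, and the flag is OR-ed into a running mask. Below is a runnable sketch of that bookkeeping, using a stand-in step function rather than the package's environment interface:

```python
import torch
from torch import full

batch, max_steps = 4, 16

def mock_env_step(actions):
    # stand-in for env.step - returns a 3-tuple with a per-environment terminate flag
    next_frame = torch.rand(batch, 3, 1, 64, 64)
    reward = torch.rand(batch)
    terminate = torch.rand(batch) < 0.25
    return next_frame, reward, terminate

is_terminated = full((batch,), False)   # which environments have already ended
episode_lens = full((batch,), 0)        # steps taken by each environment

for _ in range(max_steps):
    out = mock_env_step(None)

    if len(out) == 2:                   # environment that never reports termination
        next_frame, reward = out
        terminate = full((batch,), False)
    else:                               # environment that reports termination
        next_frame, reward, terminate = out

    # environments that already terminated stop accumulating episode length
    episode_lens = torch.where(is_terminated, episode_lens, episode_lens + 1)
    is_terminated |= terminate

    if is_terminated.all():             # stop interacting once every environment is done
        break

print(episode_lens.tolist(), is_terminated.tolist())
```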
```diff
@@ -2193,6 +2228,11 @@ class DynamicsWorldModel(Module):
             video = cat((video, next_frame), dim = 2)
             rewards = safe_cat((rewards, reward), dim = 1)
 
+            # early break out if all terminated
+
+            if is_terminated.all():
+                break
+
         # package up one experience for learning
 
         batch, device = latents.shape[0], latents.device
@@ -2206,7 +2246,8 @@ class DynamicsWorldModel(Module):
             values = values,
             step_size = step_size,
             agent_index = agent_index,
-
+            is_truncated = ~is_terminated,
+            lens = episode_lens,
             is_from_world_model = False
         )
 
@@ -2267,11 +2308,41 @@ class DynamicsWorldModel(Module):
 
         world_model_forward_context = torch.no_grad if only_learn_policy_value_heads else nullcontext
 
+        # maybe keep track returns statistics and normalize returns and values before calculating advantage, as done in dreamer v3
+
+        if self.keep_reward_ema_stats:
+            ema_returns_mean, ema_returns_var = self.ema_returns_mean, self.ema_returns_var
+
+            decay = 1. - self.reward_ema_decay
+
+            # quantile filter
+
+            lo, hi = torch.quantile(returns, self.reward_quantile_filter).tolist()
+            returns_for_stats = returns.clamp(lo, hi)
+
+            # mean, var - todo - handle distributed
+
+            returns_mean, returns_var = returns.mean(), returns.var()
+
+            # ema
+
+            ema_returns_mean.lerp_(returns_mean, decay)
+            ema_returns_var.lerp_(returns_var, decay)
+
+            # normalize
+
+            ema_returns_std = ema_returns_var.clamp(min = 1e-5).sqrt()
+
+            normed_returns = (returns - ema_returns_mean) / ema_returns_std
+            normed_old_values = (old_values - ema_returns_mean) / ema_returns_std
+
+            advantage = normed_returns - normed_old_values
+        else:
+            advantage = returns - old_values
+
         # apparently they just use the sign of the advantage
         # https://arxiv.org/abs/2410.04166v1
 
-        advantage = returns - old_values
-
         if use_signed_advantage:
             advantage = advantage.sign()
         else:
```
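Taken together with the buffers registered in the earlier hunk, the new branch normalizes returns and old values by EMA statistics before forming the advantage, in the style of Dreamer v3. A compact, self-contained sketch of the same idea follows; the `ReturnNormalizer` wrapper is illustrative and not part of the package, and note that this sketch feeds the quantile-clamped returns into the running statistics:

```python
import torch
from torch import nn, tensor

class ReturnNormalizer(nn.Module):
    # illustrative module: EMA mean / variance of quantile-clamped returns
    def __init__(self, decay = 0.998, quantile_filter = (0.05, 0.95)):
        super().__init__()
        self.decay = decay
        self.register_buffer('quantile_filter', tensor(quantile_filter), persistent = False)
        self.register_buffer('ema_returns_mean', tensor(0.))
        self.register_buffer('ema_returns_var', tensor(1.))

    @torch.no_grad()
    def update(self, returns):
        # clamp outliers to the configured quantile range before taking statistics
        lo, hi = torch.quantile(returns, self.quantile_filter).tolist()
        returns_for_stats = returns.clamp(lo, hi)

        # exponential moving average, interpolating by (1 - decay)
        weight = 1. - self.decay
        self.ema_returns_mean.lerp_(returns_for_stats.mean(), weight)
        self.ema_returns_var.lerp_(returns_for_stats.var(), weight)

    def normalize(self, values):
        std = self.ema_returns_var.clamp(min = 1e-5).sqrt()
        return (values - self.ema_returns_mean) / std

# usage: normalize returns and old values with the same statistics,
# then take their difference as the advantage
normalizer = ReturnNormalizer()
returns, old_values = torch.randn(32) * 5., torch.randn(32) * 5.

normalizer.update(returns)
advantage = normalizer.normalize(returns) - normalizer.normalize(old_values)
```

In the diff itself this path is gated by `keep_reward_ema_stats`; when it is off, the advantage falls back to the plain `returns - old_values` difference.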
dreamer4/mocks.py
CHANGED

```diff
@@ -7,6 +7,11 @@ from torch.nn import Module
 
 from einops import repeat
 
+# helpers
+
+def exists(v):
+    return v is not None
+
 # mock env
 
 class MockEnv(Module):
@@ -15,7 +20,9 @@ class MockEnv(Module):
         image_shape,
         reward_range = (-100, 100),
         num_envs = 1,
-        vectorized = False
+        vectorized = False,
+        terminate_after_step = None,
+        rand_terminate_prob = 0.05
     ):
         super().__init__()
         self.image_shape = image_shape
@@ -25,6 +32,12 @@ class MockEnv(Module):
         self.vectorized = vectorized
         assert not (vectorized and num_envs == 1)
 
+        # mocking termination
+
+        self.can_terminate = exists(terminate_after_step)
+        self.terminate_after_step = terminate_after_step
+        self.rand_terminate_prob = rand_terminate_prob
+
         self.register_buffer('_step', tensor(0))
 
     def get_random_state(self):
@@ -50,13 +63,23 @@ class MockEnv(Module):
 
         reward = empty(()).uniform_(*self.reward_range)
 
-        if
-
+        if self.vectorized:
+            discrete, continuous = actions
+            assert discrete.shape[0] == self.num_envs, f'expected batch of actions for {self.num_envs} environments'
+
+            state = repeat(state, '... -> b ...', b = self.num_envs)
+            reward = repeat(reward, ' -> b', b = self.num_envs)
+
+        out = (state, reward)
+
+        if self.can_terminate:
+            terminate = (
+                (torch.rand((self.num_envs)) < self.rand_terminate_prob) &
+                (self._step > self.terminate_after_step)
+            )
 
-
-            assert discrete.shape[0] == self.num_envs, f'expected batch of actions for {self.num_envs} environments'
+            out = (*out, terminate)
 
-
-            reward = repeat(reward, ' -> b', b = self.num_envs)
+        self._step.add_(1)
 
-        return
+        return out
```
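One possible way to exercise the updated mock with termination enabled is sketched below; the image shape, action shapes, and hyperparameter values are illustrative assumptions, not defaults from the package:

```python
import torch
from dreamer4.mocks import MockEnv

env = MockEnv(
    image_shape = (3, 64, 64),     # assumed shape for the mock observations
    num_envs = 4,
    vectorized = True,
    terminate_after_step = 2,      # turning this on makes step() return a terminate flag
    rand_terminate_prob = 0.5
)

# a (discrete, continuous) action pair, batched over the 4 environments
discrete_actions = torch.randint(0, 10, (4,))
continuous_actions = torch.randn(4, 1)

for _ in range(8):
    state, reward, terminate = env.step((discrete_actions, continuous_actions))

    if terminate.any():
        break
```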
dreamer4-0.0.82.dist-info/RECORD
ADDED

```diff
@@ -0,0 +1,8 @@
+dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
+dreamer4/dreamer4.py,sha256=1WeHCCJO-wa85Ki9uP8FR70pFTro46LF4WyFMxSD90I,110556
+dreamer4/mocks.py,sha256=S1kiENV3kHM1L6pBOLDqLVCZJg7ZydEmagiNb8sFIXc,2077
+dreamer4/trainers.py,sha256=D2b7WTgTHElLhIWLFgl2Ct2knGJLTk91HHpC5UkNvG0,14028
+dreamer4-0.0.82.dist-info/METADATA,sha256=by1KbrcKVy0epdLDLAiSN1MXR_FBgyD40Ol_Mn5iZNM,3065
+dreamer4-0.0.82.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+dreamer4-0.0.82.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+dreamer4-0.0.82.dist-info/RECORD,,
```
dreamer4-0.0.80.dist-info/RECORD
DELETED

```diff
@@ -1,8 +0,0 @@
-dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
-dreamer4/dreamer4.py,sha256=65MN3_ZqB3w5oXiWvpjRIHxxTwsACmnpw1HTd0GVJkU,108143
-dreamer4/mocks.py,sha256=Oi91Yv1oK0E-Wz-KDkf79xoyWzIXCvMLCr0WYCpJDLA,1482
-dreamer4/trainers.py,sha256=D2b7WTgTHElLhIWLFgl2Ct2knGJLTk91HHpC5UkNvG0,14028
-dreamer4-0.0.80.dist-info/METADATA,sha256=kyaYELFKTxlaWrjMYHGFiY1LtuOKl1_fBC1GSd9eo2A,3065
-dreamer4-0.0.80.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-dreamer4-0.0.80.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-dreamer4-0.0.80.dist-info/RECORD,,
```
{dreamer4-0.0.80.dist-info → dreamer4-0.0.82.dist-info}/WHEEL
File without changes

{dreamer4-0.0.80.dist-info → dreamer4-0.0.82.dist-info}/licenses/LICENSE
File without changes