dreamer4 0.0.81__py3-none-any.whl → 0.0.82__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dreamer4 might be problematic.
- dreamer4/dreamer4.py +43 -3
- dreamer4/mocks.py +31 -8
- {dreamer4-0.0.81.dist-info → dreamer4-0.0.82.dist-info}/METADATA +1 -1
- dreamer4-0.0.82.dist-info/RECORD +8 -0
- dreamer4-0.0.81.dist-info/RECORD +0 -8
- {dreamer4-0.0.81.dist-info → dreamer4-0.0.82.dist-info}/WHEEL +0 -0
- {dreamer4-0.0.81.dist-info → dreamer4-0.0.82.dist-info}/licenses/LICENSE +0 -0
dreamer4/dreamer4.py
CHANGED
@@ -1819,6 +1819,7 @@ class DynamicsWorldModel(Module):
         num_residual_streams = 1,
         keep_reward_ema_stats = False,
         reward_ema_decay = 0.998,
+        reward_quantile_filter = (0.05, 0.95),
         gae_discount_factor = 0.997,
         gae_lambda = 0.95,
         ppo_eps_clip = 0.2,
@@ -2029,6 +2030,8 @@ class DynamicsWorldModel(Module):
         self.keep_reward_ema_stats = keep_reward_ema_stats
         self.reward_ema_decay = reward_ema_decay
 
+        self.register_buffer('reward_quantile_filter', tensor(reward_quantile_filter), persistent = False)
+
         self.register_buffer('ema_returns_mean', tensor(0.))
         self.register_buffer('ema_returns_var', tensor(1.))
 
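For context on the hunk above: registering the quantile bounds with `persistent = False` keeps them out of checkpoints while still letting them follow the module across devices. A minimal sketch of that PyTorch behavior (the class name `Demo` is illustrative, not from the package):

import torch
from torch import nn, tensor

class Demo(nn.Module):
    def __init__(self, reward_quantile_filter = (0.05, 0.95)):
        super().__init__()
        # non-persistent buffers move with .to(device) / .cuda(),
        # but are excluded from state_dict, so existing checkpoints
        # still load if this hyperparameter changes between versions
        self.register_buffer('reward_quantile_filter', tensor(reward_quantile_filter), persistent = False)

demo = Demo()
assert 'reward_quantile_filter' not in demo.state_dict()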
@@ -2115,6 +2118,8 @@ class DynamicsWorldModel(Module):
         else:
             video = rearrange(init_frame, 'c vh vw -> 1 c 1 vh vw')
 
+        batch, device = video.shape[0], video.device
+
         # accumulate
 
         rewards = None
@@ -2125,6 +2130,11 @@ class DynamicsWorldModel(Module):
         values = None
         latents = None
 
+        # keep track of termination, for setting the `is_truncated` field on Experience and for early stopping interaction with env
+
+        is_terminated = full((batch,), False, device = device)
+        episode_lens = full((batch,), 0, device = device)
+
         # maybe time kv cache
 
         time_kv_cache = None
@@ -2187,7 +2197,22 @@ class DynamicsWorldModel(Module):
 
             # pass the sampled action to the environment and get back next state and reward
 
-            next_frame, reward = env.step((sampled_discrete_actions, sampled_continuous_actions))
+            env_step_out = env.step((sampled_discrete_actions, sampled_continuous_actions))
+
+            if len(env_step_out) == 2:
+                next_frame, reward = env_step_out
+                terminate = full((batch,), False)
+
+            elif len(env_step_out) == 3:
+                next_frame, reward, terminate = env_step_out
+
+            # update episode lens
+
+            episode_lens = torch.where(is_terminated, episode_lens, episode_lens + 1)
+
+            # update `is_terminated`
+
+            is_terminated |= terminate
 
             # batch and time dimension
 
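The hunk above makes the rollout tolerant of environments that may or may not report termination: `env.step` can return either `(next_frame, reward)` or `(next_frame, reward, terminate)`. A standalone sketch of the dispatch and the bookkeeping — `fake_env_step` and its tensor shapes are hypothetical stand-ins, not the package's API:

import torch
from torch import full

batch = 4

def fake_env_step(with_terminate):
    # hypothetical environment step; shapes are illustrative
    next_frame = torch.randn(batch, 3, 1, 64, 64)
    reward = torch.randn(batch, 1)
    if with_terminate:
        return next_frame, reward, torch.rand(batch) < 0.5
    return next_frame, reward

is_terminated = full((batch,), False)
episode_lens = full((batch,), 0)

env_step_out = fake_env_step(with_terminate = True)

if len(env_step_out) == 2:
    next_frame, reward = env_step_out
    terminate = full((batch,), False)  # env never reports termination
elif len(env_step_out) == 3:
    next_frame, reward, terminate = env_step_out

# episode lengths stop growing once an env has terminated,
# and a terminated env stays terminated
episode_lens = torch.where(is_terminated, episode_lens, episode_lens + 1)
is_terminated |= terminate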
@@ -2203,6 +2228,11 @@ class DynamicsWorldModel(Module):
             video = cat((video, next_frame), dim = 2)
             rewards = safe_cat((rewards, reward), dim = 1)
 
+            # early break out if all terminated
+
+            if is_terminated.all():
+                break
+
         # package up one experience for learning
 
         batch, device = latents.shape[0], latents.device
@@ -2216,7 +2246,8 @@ class DynamicsWorldModel(Module):
             values = values,
             step_size = step_size,
             agent_index = agent_index,
-
+            is_truncated = ~is_terminated,
+            lens = episode_lens,
             is_from_world_model = False
         )
 
@@ -2284,13 +2315,22 @@ class DynamicsWorldModel(Module):
 
         decay = 1. - self.reward_ema_decay
 
-        # mean, var - todo - handle distributed
+        # quantile filter
+
+        lo, hi = torch.quantile(returns, self.reward_quantile_filter).tolist()
+        returns_for_stats = returns.clamp(lo, hi)
+
+        # mean, var - todo - handle distributed
 
         returns_mean, returns_var = returns.mean(), returns.var()
 
+        # ema
+
         ema_returns_mean.lerp_(returns_mean, decay)
         ema_returns_var.lerp_(returns_var, decay)
 
+        # normalize
+
         ema_returns_std = ema_returns_var.clamp(min = 1e-5).sqrt()
 
         normed_returns = (returns - ema_returns_mean) / ema_returns_std
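Taken together, the return-normalization path now clamps outliers to the configured quantiles before the EMA statistics. Worth noting: the hunk computes `returns_for_stats` but, as released, still takes `returns.mean()` / `returns.var()` on the unclamped tensor. The sketch below feeds the clamped values into the statistics, which appears to be the intent of the filter — treat it as an interpretation, not the package's exact behavior:

import torch

reward_quantile_filter = torch.tensor((0.05, 0.95))
ema_returns_mean = torch.tensor(0.)
ema_returns_var = torch.tensor(1.)
decay = 1. - 0.998  # one minus the EMA decay

returns = torch.randn(1024) * 10.

# clamp outliers to the 5% / 95% quantiles before taking statistics
lo, hi = torch.quantile(returns, reward_quantile_filter).tolist()
returns_for_stats = returns.clamp(lo, hi)

returns_mean, returns_var = returns_for_stats.mean(), returns_for_stats.var()

# exponential moving average of the statistics
ema_returns_mean.lerp_(returns_mean, decay)
ema_returns_var.lerp_(returns_var, decay)

# normalize the raw returns with the smoothed statistics
ema_returns_std = ema_returns_var.clamp(min = 1e-5).sqrt()
normed_returns = (returns - ema_returns_mean) / ema_returns_std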
dreamer4/mocks.py
CHANGED
@@ -7,6 +7,11 @@ from torch.nn import Module
 
 from einops import repeat
 
+# helpers
+
+def exists(v):
+    return v is not None
+
 # mock env
 
 class MockEnv(Module):
@@ -15,7 +20,9 @@ class MockEnv(Module):
         image_shape,
         reward_range = (-100, 100),
         num_envs = 1,
-        vectorized = False
+        vectorized = False,
+        terminate_after_step = None,
+        rand_terminate_prob = 0.05
     ):
         super().__init__()
         self.image_shape = image_shape
@@ -25,6 +32,12 @@ class MockEnv(Module):
         self.vectorized = vectorized
         assert not (vectorized and num_envs == 1)
 
+        # mocking termination
+
+        self.can_terminate = exists(terminate_after_step)
+        self.terminate_after_step = terminate_after_step
+        self.rand_terminate_prob = rand_terminate_prob
+
         self.register_buffer('_step', tensor(0))
 
     def get_random_state(self):
@@ -50,13 +63,23 @@ class MockEnv(Module):
 
         reward = empty(()).uniform_(*self.reward_range)
 
-        if self.vectorized:
-
+        if self.vectorized:
+            discrete, continuous = actions
+            assert discrete.shape[0] == self.num_envs, f'expected batch of actions for {self.num_envs} environments'
+
+            state = repeat(state, '... -> b ...', b = self.num_envs)
+            reward = repeat(reward, ' -> b', b = self.num_envs)
+
+        out = (state, reward)
+
+        if self.can_terminate:
+            terminate = (
+                (torch.rand((self.num_envs)) < self.rand_terminate_prob) &
+                (self._step > self.terminate_after_step)
+            )
 
-
-            assert discrete.shape[0] == self.num_envs, f'expected batch of actions for {self.num_envs} environments'
+            out = (*out, terminate)
 
-
-            reward = repeat(reward, ' -> b', b = self.num_envs)
+        self._step.add_(1)
 
-        return state, reward
+        return out
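With the new arguments, MockEnv can exercise the termination path end to end: once `_step` exceeds `terminate_after_step`, each vectorized env terminates with probability `rand_terminate_prob`, and `step` returns a third `terminate` tensor. A hypothetical usage sketch — the action shapes and `image_shape` value are illustrative guesses, not documented by the package:

import torch
from dreamer4.mocks import MockEnv

env = MockEnv(
    image_shape = (3, 64, 64),      # illustrative
    num_envs = 4,
    vectorized = True,
    terminate_after_step = 10,      # enables the (state, reward, terminate) return
    rand_terminate_prob = 0.05
)

discrete_actions = torch.randint(0, 10, (4, 1))   # batch must equal num_envs
continuous_actions = torch.randn(4, 1)

out = env.step((discrete_actions, continuous_actions))

if len(out) == 3:
    state, reward, terminate = out   # terminate: (num_envs,) bool tensor
else:
    state, reward = out              # env constructed without terminate_after_step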
dreamer4-0.0.82.dist-info/RECORD
ADDED
@@ -0,0 +1,8 @@
+dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
+dreamer4/dreamer4.py,sha256=1WeHCCJO-wa85Ki9uP8FR70pFTro46LF4WyFMxSD90I,110556
+dreamer4/mocks.py,sha256=S1kiENV3kHM1L6pBOLDqLVCZJg7ZydEmagiNb8sFIXc,2077
+dreamer4/trainers.py,sha256=D2b7WTgTHElLhIWLFgl2Ct2knGJLTk91HHpC5UkNvG0,14028
+dreamer4-0.0.82.dist-info/METADATA,sha256=by1KbrcKVy0epdLDLAiSN1MXR_FBgyD40Ol_Mn5iZNM,3065
+dreamer4-0.0.82.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+dreamer4-0.0.82.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+dreamer4-0.0.82.dist-info/RECORD,,
dreamer4-0.0.81.dist-info/RECORD
DELETED
@@ -1,8 +0,0 @@
-dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
-dreamer4/dreamer4.py,sha256=8IxpSzfpm_PLqwJ1YReVRuLuvbvQJNrhIKCvqK35jdQ,109317
-dreamer4/mocks.py,sha256=Oi91Yv1oK0E-Wz-KDkf79xoyWzIXCvMLCr0WYCpJDLA,1482
-dreamer4/trainers.py,sha256=D2b7WTgTHElLhIWLFgl2Ct2knGJLTk91HHpC5UkNvG0,14028
-dreamer4-0.0.81.dist-info/METADATA,sha256=8qDj5Pw3N_hqXOrGFCnP6bgbTqB7B2dn5_9FRn4mExw,3065
-dreamer4-0.0.81.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-dreamer4-0.0.81.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-dreamer4-0.0.81.dist-info/RECORD,,
{dreamer4-0.0.81.dist-info → dreamer4-0.0.82.dist-info}/WHEEL
File without changes
{dreamer4-0.0.81.dist-info → dreamer4-0.0.82.dist-info}/licenses/LICENSE
File without changes