dreamer4 0.0.81__py3-none-any.whl → 0.0.83__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dreamer4 might be problematic.
- dreamer4/dreamer4.py +62 -4
- dreamer4/mocks.py +43 -8
- {dreamer4-0.0.81.dist-info → dreamer4-0.0.83.dist-info}/METADATA +1 -1
- dreamer4-0.0.83.dist-info/RECORD +8 -0
- dreamer4-0.0.81.dist-info/RECORD +0 -8
- {dreamer4-0.0.81.dist-info → dreamer4-0.0.83.dist-info}/WHEEL +0 -0
- {dreamer4-0.0.81.dist-info → dreamer4-0.0.83.dist-info}/licenses/LICENSE +0 -0
dreamer4/dreamer4.py
CHANGED
@@ -1819,6 +1819,7 @@ class DynamicsWorldModel(Module):
         num_residual_streams = 1,
         keep_reward_ema_stats = False,
         reward_ema_decay = 0.998,
+        reward_quantile_filter = (0.05, 0.95),
         gae_discount_factor = 0.997,
         gae_lambda = 0.95,
         ppo_eps_clip = 0.2,
@@ -2029,6 +2030,8 @@ class DynamicsWorldModel(Module):
         self.keep_reward_ema_stats = keep_reward_ema_stats
         self.reward_ema_decay = reward_ema_decay

+        self.register_buffer('reward_quantile_filter', tensor(reward_quantile_filter), persistent = False)
+
         self.register_buffer('ema_returns_mean', tensor(0.))
         self.register_buffer('ema_returns_var', tensor(1.))

@@ -2115,6 +2118,8 @@ class DynamicsWorldModel(Module):
         else:
             video = rearrange(init_frame, 'c vh vw -> 1 c 1 vh vw')

+        batch, device = video.shape[0], video.device
+
         # accumulate

         rewards = None
@@ -2125,11 +2130,21 @@
         values = None
         latents = None

+        # keep track of termination, for setting the `is_truncated` field on Experience and for early stopping interaction with env
+
+        is_terminated = full((batch,), False, device = device)
+        is_truncated = full((batch,), False, device = device)
+
+        episode_lens = full((batch,), 0, device = device)
+
         # maybe time kv cache

         time_kv_cache = None

-
+        step_index = 0
+
+        while not is_terminated.all():
+            step_index += 1

             latents = self.video_tokenizer(video, return_latents = True)

@@ -2187,7 +2202,40 @@

             # pass the sampled action to the environment and get back next state and reward

-
+            env_step_out = env.step((sampled_discrete_actions, sampled_continuous_actions))
+
+            if len(env_step_out) == 2:
+                next_frame, reward = env_step_out
+                terminated = full((batch,), False)
+                truncated = full((batch,), False)
+
+            elif len(env_step_out) == 3:
+                next_frame, reward, terminated = env_step_out
+                truncated = full((batch,), False)
+
+            elif len(env_step_out) == 4:
+                next_frame, reward, terminated, truncated = env_step_out
+
+            # update episode lens
+
+            episode_lens = torch.where(is_terminated, episode_lens, episode_lens + 1)
+
+            # update `is_terminated`
+
+            # (1) - environment says it is terminated
+            # (2) - previous step is truncated (this step is for bootstrap value)
+
+            is_terminated |= (terminated | is_truncated)
+
+            # update `is_truncated`
+
+            if step_index <= max_timesteps:
+                is_truncated |= truncated
+
+            if step_index == max_timesteps:
+                # if the step index is at the max time step allowed, set the truncated flag, if not already terminated
+
+                is_truncated |= ~is_terminated

             # batch and time dimension

@@ -2216,7 +2264,8 @@
             values = values,
             step_size = step_size,
             agent_index = agent_index,
-
+            is_truncated = is_truncated,
+            lens = episode_lens,
             is_from_world_model = False
         )

@@ -2284,13 +2333,22 @@

         decay = 1. - self.reward_ema_decay

-        #
+        # quantile filter
+
+        lo, hi = torch.quantile(returns, self.reward_quantile_filter).tolist()
+        returns_for_stats = returns.clamp(lo, hi)
+
+        # mean, var - todo - handle distributed

         returns_mean, returns_var = returns.mean(), returns.var()

+        # ema
+
         ema_returns_mean.lerp_(returns_mean, decay)
         ema_returns_var.lerp_(returns_var, decay)

+        # normalize
+
         ema_returns_std = ema_returns_var.clamp(min = 1e-5).sqrt()

         normed_returns = (returns - ema_returns_mean) / ema_returns_std
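The hunks above extend the environment rollout so the world model can consume env.step outputs of two, three, or four elements and track per-environment termination and truncation. Below is a minimal standalone sketch of that bookkeeping in plain PyTorch; the helper names (unpack_env_step, update_termination) are hypothetical and not part of the package.

import torch

def unpack_env_step(env_step_out, batch):
    # accept (state, reward), (state, reward, terminated) or
    # (state, reward, terminated, truncated); missing flags default to all-False
    if len(env_step_out) == 2:
        next_frame, reward = env_step_out
        terminated = torch.full((batch,), False)
        truncated = torch.full((batch,), False)
    elif len(env_step_out) == 3:
        next_frame, reward, terminated = env_step_out
        truncated = torch.full((batch,), False)
    else:
        next_frame, reward, terminated, truncated = env_step_out
    return next_frame, reward, terminated, truncated

def update_termination(is_terminated, is_truncated, episode_lens, terminated, truncated, step_index, max_timesteps):
    # episodes that have already terminated stop accumulating length
    episode_lens = torch.where(is_terminated, episode_lens, episode_lens + 1)

    # an episode counts as terminated if the env says so, or if the previous
    # step was truncated (that extra step only exists to bootstrap the value)
    is_terminated = is_terminated | terminated | is_truncated

    if step_index <= max_timesteps:
        is_truncated = is_truncated | truncated

    if step_index == max_timesteps:
        # hitting the step budget truncates whatever has not terminated on its own
        is_truncated = is_truncated | ~is_terminated

    return is_terminated, is_truncated, episode_lens

The new reward_quantile_filter feeds the EMA return statistics: returns are clamped to the configured quantiles before the mean/variance update. In the released hunk, returns_for_stats does not appear to be consumed within the visible window (the statistics are still taken from the raw returns); the sketch below applies the clamp, which appears to be the intent. Again a rough illustration, not the package API.

def normalize_returns(returns, ema_mean, ema_var, quantile_filter = (0.05, 0.95), ema_decay = 0.998):
    # clamp outlier returns to the requested quantiles before taking statistics
    lo, hi = torch.quantile(returns, torch.tensor(quantile_filter)).tolist()
    returns_for_stats = returns.clamp(lo, hi)

    # exponential moving average of mean and variance, mirroring the lerp_ calls in the diff
    decay = 1. - ema_decay
    ema_mean = torch.lerp(ema_mean, returns_for_stats.mean(), decay)
    ema_var = torch.lerp(ema_var, returns_for_stats.var(), decay)

    # normalize with a variance floor for numerical stability
    normed = (returns - ema_mean) / ema_var.clamp(min = 1e-5).sqrt()
    return normed, ema_mean, ema_var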
dreamer4/mocks.py
CHANGED
@@ -7,6 +7,11 @@ from torch.nn import Module

 from einops import repeat

+# helpers
+
+def exists(v):
+    return v is not None
+
 # mock env

 class MockEnv(Module):
@@ -15,7 +20,11 @@ class MockEnv(Module):
         image_shape,
         reward_range = (-100, 100),
         num_envs = 1,
-        vectorized = False
+        vectorized = False,
+        terminate_after_step = None,
+        rand_terminate_prob = 0.05,
+        can_truncate = False,
+        rand_truncate_prob = 0.05,
     ):
         super().__init__()
         self.image_shape = image_shape
@@ -25,6 +34,15 @@ class MockEnv(Module):
         self.vectorized = vectorized
         assert not (vectorized and num_envs == 1)

+        # mocking termination and truncation
+
+        self.can_terminate = exists(terminate_after_step)
+        self.terminate_after_step = terminate_after_step
+        self.rand_terminate_prob = rand_terminate_prob
+
+        self.can_truncate = can_truncate
+        self.rand_truncate_prob = rand_truncate_prob
+
         self.register_buffer('_step', tensor(0))

     def get_random_state(self):
@@ -50,13 +68,30 @@

         reward = empty(()).uniform_(*self.reward_range)

-        if
-
+        if self.vectorized:
+            discrete, continuous = actions
+            assert discrete.shape[0] == self.num_envs, f'expected batch of actions for {self.num_envs} environments'
+
+            state = repeat(state, '... -> b ...', b = self.num_envs)
+            reward = repeat(reward, ' -> b', b = self.num_envs)
+
+        out = (state, reward)
+
+
+        if self.can_terminate:
+            shape = (self.num_envs,) if self.vectorized else (1,)
+            valid_step = self._step > self.terminate_after_step
+
+            terminate = (torch.rand(shape) < self.rand_terminate_prob) & valid_step
+
+            out = (*out, terminate)
+
+        # maybe truncation

-
-
+        if self.can_truncate:
+            truncate = (torch.rand(shape) < self.rand_truncate_prob) & valid_step & ~terminate
+            out = (*out, truncate)

-
-        reward = repeat(reward, ' -> b', b = self.num_envs)
+        self._step.add_(1)

-        return
+        return out
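With the new options, MockEnv can emit termination and truncation flags, so its step output grows from (state, reward) to up to (state, reward, terminate, truncate). A rough usage sketch follows, assuming MockEnv exposes the step method that dreamer4.py calls; the image shape and action shapes are illustrative assumptions, not values taken from the package.

import torch
from dreamer4.mocks import MockEnv

env = MockEnv(
    image_shape = (3, 64, 64),    # assumed (channels, height, width)
    num_envs = 4,
    vectorized = True,
    terminate_after_step = 10,    # termination only becomes possible once _step exceeds this
    rand_terminate_prob = 0.05,
    can_truncate = True,
    rand_truncate_prob = 0.05,
)

# step expects a (discrete, continuous) action pair; in vectorized mode the
# leading dimension must equal num_envs (see the assert in the hunk above)
discrete_actions = torch.randint(0, 2, (4, 1))
continuous_actions = torch.randn(4, 1)

out = env.step((discrete_actions, continuous_actions))

# with both termination and truncation enabled, four values come back
state, reward, terminate, truncate = out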
dreamer4-0.0.83.dist-info/RECORD
ADDED
@@ -0,0 +1,8 @@
+dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
+dreamer4/dreamer4.py,sha256=t9g-fJVRPn4nefpNiKwqkIkKZtIeJAn5V4ruNCJ5a9A,111265
+dreamer4/mocks.py,sha256=TfqOB_Gq6N_GggBYwa6ZAJQx38ntlYbXZe23Ne4jshw,2502
+dreamer4/trainers.py,sha256=D2b7WTgTHElLhIWLFgl2Ct2knGJLTk91HHpC5UkNvG0,14028
+dreamer4-0.0.83.dist-info/METADATA,sha256=pYIw5Pj41JhfJ4Hp9AfegXrxnhIZ9Fk98kF7lY3nAZk,3065
+dreamer4-0.0.83.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+dreamer4-0.0.83.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+dreamer4-0.0.83.dist-info/RECORD,,
dreamer4-0.0.81.dist-info/RECORD
DELETED
@@ -1,8 +0,0 @@
-dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
-dreamer4/dreamer4.py,sha256=8IxpSzfpm_PLqwJ1YReVRuLuvbvQJNrhIKCvqK35jdQ,109317
-dreamer4/mocks.py,sha256=Oi91Yv1oK0E-Wz-KDkf79xoyWzIXCvMLCr0WYCpJDLA,1482
-dreamer4/trainers.py,sha256=D2b7WTgTHElLhIWLFgl2Ct2knGJLTk91HHpC5UkNvG0,14028
-dreamer4-0.0.81.dist-info/METADATA,sha256=8qDj5Pw3N_hqXOrGFCnP6bgbTqB7B2dn5_9FRn4mExw,3065
-dreamer4-0.0.81.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-dreamer4-0.0.81.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-dreamer4-0.0.81.dist-info/RECORD,,
{dreamer4-0.0.81.dist-info → dreamer4-0.0.83.dist-info}/WHEEL
File without changes
{dreamer4-0.0.81.dist-info → dreamer4-0.0.83.dist-info}/licenses/LICENSE
File without changes