dreamer4 0.0.81__tar.gz → 0.0.83__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of dreamer4 might be problematic.
- {dreamer4-0.0.81 → dreamer4-0.0.83}/PKG-INFO +1 -1
- {dreamer4-0.0.81 → dreamer4-0.0.83}/dreamer4/dreamer4.py +62 -4
- dreamer4-0.0.83/dreamer4/mocks.py +97 -0
- {dreamer4-0.0.81 → dreamer4-0.0.83}/pyproject.toml +1 -1
- {dreamer4-0.0.81 → dreamer4-0.0.83}/tests/test_dreamer.py +13 -2
- dreamer4-0.0.81/dreamer4/mocks.py +0 -62
- {dreamer4-0.0.81 → dreamer4-0.0.83}/.github/workflows/python-publish.yml +0 -0
- {dreamer4-0.0.81 → dreamer4-0.0.83}/.github/workflows/test.yml +0 -0
- {dreamer4-0.0.81 → dreamer4-0.0.83}/.gitignore +0 -0
- {dreamer4-0.0.81 → dreamer4-0.0.83}/LICENSE +0 -0
- {dreamer4-0.0.81 → dreamer4-0.0.83}/README.md +0 -0
- {dreamer4-0.0.81 → dreamer4-0.0.83}/dreamer4/__init__.py +0 -0
- {dreamer4-0.0.81 → dreamer4-0.0.83}/dreamer4/trainers.py +0 -0
- {dreamer4-0.0.81 → dreamer4-0.0.83}/dreamer4-fig2.png +0 -0
dreamer4/dreamer4.py

@@ -1819,6 +1819,7 @@ class DynamicsWorldModel(Module):
         num_residual_streams = 1,
         keep_reward_ema_stats = False,
         reward_ema_decay = 0.998,
+        reward_quantile_filter = (0.05, 0.95),
         gae_discount_factor = 0.997,
         gae_lambda = 0.95,
         ppo_eps_clip = 0.2,
@@ -2029,6 +2030,8 @@ class DynamicsWorldModel(Module):
         self.keep_reward_ema_stats = keep_reward_ema_stats
         self.reward_ema_decay = reward_ema_decay
 
+        self.register_buffer('reward_quantile_filter', tensor(reward_quantile_filter), persistent = False)
+
         self.register_buffer('ema_returns_mean', tensor(0.))
         self.register_buffer('ema_returns_var', tensor(1.))
 
@@ -2115,6 +2118,8 @@ class DynamicsWorldModel(Module):
         else:
             video = rearrange(init_frame, 'c vh vw -> 1 c 1 vh vw')
 
+        batch, device = video.shape[0], video.device
+
         # accumulate
 
         rewards = None
@@ -2125,11 +2130,21 @@ class DynamicsWorldModel(Module):
         values = None
         latents = None
 
+        # keep track of termination, for setting the `is_truncated` field on Experience and for early stopping interaction with env
+
+        is_terminated = full((batch,), False, device = device)
+        is_truncated = full((batch,), False, device = device)
+
+        episode_lens = full((batch,), 0, device = device)
+
         # maybe time kv cache
 
         time_kv_cache = None
 
-
+        step_index = 0
+
+        while not is_terminated.all():
+            step_index += 1
 
             latents = self.video_tokenizer(video, return_latents = True)
 
@@ -2187,7 +2202,40 @@ class DynamicsWorldModel(Module):
 
             # pass the sampled action to the environment and get back next state and reward
 
-
+            env_step_out = env.step((sampled_discrete_actions, sampled_continuous_actions))
+
+            if len(env_step_out) == 2:
+                next_frame, reward = env_step_out
+                terminated = full((batch,), False)
+                truncated = full((batch,), False)
+
+            elif len(env_step_out) == 3:
+                next_frame, reward, terminated = env_step_out
+                truncated = full((batch,), False)
+
+            elif len(env_step_out) == 4:
+                next_frame, reward, terminated, truncated = env_step_out
+
+            # update episode lens
+
+            episode_lens = torch.where(is_terminated, episode_lens, episode_lens + 1)
+
+            # update `is_terminated`
+
+            # (1) - environment says it is terminated
+            # (2) - previous step is truncated (this step is for bootstrap value)
+
+            is_terminated |= (terminated | is_truncated)
+
+            # update `is_truncated`
+
+            if step_index <= max_timesteps:
+                is_truncated |= truncated
+
+            if step_index == max_timesteps:
+                # if the step index is at the max time step allowed, set the truncated flag, if not already terminated
+
+                is_truncated |= ~is_terminated
 
             # batch and time dimension
 
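The branching on `len(env_step_out)` lets the world model interact both with environments that return a bare `(state, reward)` pair and with gym-style environments that also report `terminated` and/or `truncated` flags; missing flags are padded with all-False tensors. A standalone sketch of the same normalization is below (the `env_step_out` and `batch` arguments here are placeholders for illustration, not part of dreamer4's API):

```python
from torch import full

def normalize_step_output(env_step_out, batch):
    # pad the step output to a uniform (next_frame, reward, terminated, truncated) 4-tuple
    if len(env_step_out) == 2:
        next_frame, reward = env_step_out
        terminated = full((batch,), False)
        truncated = full((batch,), False)
    elif len(env_step_out) == 3:
        next_frame, reward, terminated = env_step_out
        truncated = full((batch,), False)
    elif len(env_step_out) == 4:
        next_frame, reward, terminated, truncated = env_step_out
    else:
        raise ValueError(f'expected 2, 3, or 4 outputs from env.step, got {len(env_step_out)}')

    return next_frame, reward, terminated, truncated
```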
@@ -2216,7 +2264,8 @@ class DynamicsWorldModel(Module):
             values = values,
             step_size = step_size,
             agent_index = agent_index,
-
+            is_truncated = is_truncated,
+            lens = episode_lens,
             is_from_world_model = False
         )
 
@@ -2284,13 +2333,22 @@ class DynamicsWorldModel(Module):
 
         decay = 1. - self.reward_ema_decay
 
-        #
+        # quantile filter
+
+        lo, hi = torch.quantile(returns, self.reward_quantile_filter).tolist()
+        returns_for_stats = returns.clamp(lo, hi)
+
+        # mean, var - todo - handle distributed
 
         returns_mean, returns_var = returns.mean(), returns.var()
 
+        # ema
+
         ema_returns_mean.lerp_(returns_mean, decay)
         ema_returns_var.lerp_(returns_var, decay)
 
+        # normalize
+
         ema_returns_std = ema_returns_var.clamp(min = 1e-5).sqrt()
 
         normed_returns = (returns - ema_returns_mean) / ema_returns_std
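The new `reward_quantile_filter` clamps returns to an inner quantile range before they feed the EMA statistics, so outlier returns do not dominate the running mean and variance used for normalization. A minimal, self-contained sketch of the idea (a toy tensor of returns and the default constants, outside the model, with the clamped values feeding the statistics):

```python
import torch

returns = torch.randn(128) * 10          # toy batch of returns
quantiles = torch.tensor([0.05, 0.95])   # same default as reward_quantile_filter
decay = 1. - 0.998                       # same default as reward_ema_decay

# clamp to the inner quantile range so outliers do not skew the running statistics
lo, hi = torch.quantile(returns, quantiles).tolist()
returns_for_stats = returns.clamp(lo, hi)

# running (EMA) mean / variance, then normalize the raw returns with them
ema_mean, ema_var = torch.tensor(0.), torch.tensor(1.)
ema_mean.lerp_(returns_for_stats.mean(), decay)
ema_var.lerp_(returns_for_stats.var(), decay)

normed_returns = (returns - ema_mean) / ema_var.clamp(min = 1e-5).sqrt()
```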
dreamer4-0.0.83/dreamer4/mocks.py (new file)

@@ -0,0 +1,97 @@
+from __future__ import annotations
+from random import choice
+
+import torch
+from torch import tensor, empty, randn, randint
+from torch.nn import Module
+
+from einops import repeat
+
+# helpers
+
+def exists(v):
+    return v is not None
+
+# mock env
+
+class MockEnv(Module):
+    def __init__(
+        self,
+        image_shape,
+        reward_range = (-100, 100),
+        num_envs = 1,
+        vectorized = False,
+        terminate_after_step = None,
+        rand_terminate_prob = 0.05,
+        can_truncate = False,
+        rand_truncate_prob = 0.05,
+    ):
+        super().__init__()
+        self.image_shape = image_shape
+        self.reward_range = reward_range
+
+        self.num_envs = num_envs
+        self.vectorized = vectorized
+        assert not (vectorized and num_envs == 1)
+
+        # mocking termination and truncation
+
+        self.can_terminate = exists(terminate_after_step)
+        self.terminate_after_step = terminate_after_step
+        self.rand_terminate_prob = rand_terminate_prob
+
+        self.can_truncate = can_truncate
+        self.rand_truncate_prob = rand_truncate_prob
+
+        self.register_buffer('_step', tensor(0))
+
+    def get_random_state(self):
+        return randn(3, *self.image_shape)
+
+    def reset(
+        self,
+        seed = None
+    ):
+        self._step.zero_()
+        state = self.get_random_state()
+
+        if self.vectorized:
+            state = repeat(state, '... -> b ...', b = self.num_envs)
+
+        return state
+
+    def step(
+        self,
+        actions,
+    ):
+        state = self.get_random_state()
+
+        reward = empty(()).uniform_(*self.reward_range)
+
+        if self.vectorized:
+            discrete, continuous = actions
+            assert discrete.shape[0] == self.num_envs, f'expected batch of actions for {self.num_envs} environments'
+
+            state = repeat(state, '... -> b ...', b = self.num_envs)
+            reward = repeat(reward, ' -> b', b = self.num_envs)
+
+        out = (state, reward)
+
+
+        if self.can_terminate:
+            shape = (self.num_envs,) if self.vectorized else (1,)
+            valid_step = self._step > self.terminate_after_step
+
+            terminate = (torch.rand(shape) < self.rand_terminate_prob) & valid_step
+
+            out = (*out, terminate)
+
+            # maybe truncation
+
+            if self.can_truncate:
+                truncate = (torch.rand(shape) < self.rand_truncate_prob) & valid_step & ~terminate
+                out = (*out, truncate)
+
+        self._step.add_(1)
+
+        return out
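For reference, the new mock environment is exercised in the updated test below; the construction here is only a usage sketch based on that test, with a smaller image size and illustrative action shapes (the discrete/continuous action dimensions are assumptions, not dreamer4 defaults):

```python
import torch
from dreamer4.mocks import MockEnv

# vectorized mock env that can terminate after a couple of steps and can also truncate
env = MockEnv(
    (64, 64),
    vectorized = True,
    num_envs = 4,
    terminate_after_step = 2,
    can_truncate = True,
    rand_terminate_prob = 0.1
)

state = env.reset()                      # (4, 3, 64, 64)

# actions are passed as a (discrete, continuous) pair; shapes here are illustrative only
discrete = torch.randint(0, 2, (4, 1))
continuous = torch.randn(4, 1)

out = env.step((discrete, continuous))   # (state, reward, terminate, truncate) once termination is enabled
```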
tests/test_dreamer.py

@@ -612,9 +612,13 @@ def test_cache_generate():
 
 @param('vectorized', (False, True))
 @param('use_signed_advantage', (False, True))
+@param('env_can_terminate', (False, True))
+@param('env_can_truncate', (False, True))
 def test_online_rl(
     vectorized,
-    use_signed_advantage
+    use_signed_advantage,
+    env_can_terminate,
+    env_can_truncate
 ):
     from dreamer4.dreamer4 import DynamicsWorldModel, VideoTokenizer
 
@@ -649,7 +653,14 @@ def test_online_rl(
     from dreamer4.mocks import MockEnv
     from dreamer4.dreamer4 import combine_experiences
 
-    mock_env = MockEnv(
+    mock_env = MockEnv(
+        (256, 256),
+        vectorized = vectorized,
+        num_envs = 4,
+        terminate_after_step = 2 if env_can_terminate else None,
+        can_truncate = env_can_truncate,
+        rand_terminate_prob = 0.1
+    )
 
     # manually
 
dreamer4-0.0.81/dreamer4/mocks.py (removed)

@@ -1,62 +0,0 @@
-from __future__ import annotations
-from random import choice
-
-import torch
-from torch import tensor, empty, randn, randint
-from torch.nn import Module
-
-from einops import repeat
-
-# mock env
-
-class MockEnv(Module):
-    def __init__(
-        self,
-        image_shape,
-        reward_range = (-100, 100),
-        num_envs = 1,
-        vectorized = False
-    ):
-        super().__init__()
-        self.image_shape = image_shape
-        self.reward_range = reward_range
-
-        self.num_envs = num_envs
-        self.vectorized = vectorized
-        assert not (vectorized and num_envs == 1)
-
-        self.register_buffer('_step', tensor(0))
-
-    def get_random_state(self):
-        return randn(3, *self.image_shape)
-
-    def reset(
-        self,
-        seed = None
-    ):
-        self._step.zero_()
-        state = self.get_random_state()
-
-        if self.vectorized:
-            state = repeat(state, '... -> b ...', b = self.num_envs)
-
-        return state
-
-    def step(
-        self,
-        actions,
-    ):
-        state = self.get_random_state()
-
-        reward = empty(()).uniform_(*self.reward_range)
-
-        if not self.vectorized:
-            return state, reward
-
-        discrete, continuous = actions
-        assert discrete.shape[0] == self.num_envs, f'expected batch of actions for {self.num_envs} environments'
-
-        state = repeat(state, '... -> b ...', b = self.num_envs)
-        reward = repeat(reward, ' -> b', b = self.num_envs)
-
-        return state, reward