dreamer4 0.0.81__tar.gz → 0.0.83__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: dreamer4
- Version: 0.0.81
+ Version: 0.0.83
  Summary: Dreamer 4
  Project-URL: Homepage, https://pypi.org/project/dreamer4/
  Project-URL: Repository, https://github.com/lucidrains/dreamer4
@@ -1819,6 +1819,7 @@ class DynamicsWorldModel(Module):
  num_residual_streams = 1,
  keep_reward_ema_stats = False,
  reward_ema_decay = 0.998,
+ reward_quantile_filter = (0.05, 0.95),
  gae_discount_factor = 0.997,
  gae_lambda = 0.95,
  ppo_eps_clip = 0.2,
@@ -2029,6 +2030,8 @@ class DynamicsWorldModel(Module):
  self.keep_reward_ema_stats = keep_reward_ema_stats
  self.reward_ema_decay = reward_ema_decay

+ self.register_buffer('reward_quantile_filter', tensor(reward_quantile_filter), persistent = False)
+
  self.register_buffer('ema_returns_mean', tensor(0.))
  self.register_buffer('ema_returns_var', tensor(1.))

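The quantile bounds are registered as a non-persistent buffer, so they follow the module across devices without being written into checkpoints. A minimal sketch of that PyTorch behavior, using a throwaway `Example` module rather than anything from the package:

import torch
from torch import nn, tensor

class Example(nn.Module):
    def __init__(self, reward_quantile_filter = (0.05, 0.95)):
        super().__init__()
        # non-persistent buffers move with .to() / .cuda() but are excluded from state_dict()
        self.register_buffer('reward_quantile_filter', tensor(reward_quantile_filter), persistent = False)

module = Example()
assert 'reward_quantile_filter' not in module.state_dict()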
@@ -2115,6 +2118,8 @@ class DynamicsWorldModel(Module):
  else:
  video = rearrange(init_frame, 'c vh vw -> 1 c 1 vh vw')

+ batch, device = video.shape[0], video.device
+
  # accumulate

  rewards = None
@@ -2125,11 +2130,21 @@ class DynamicsWorldModel(Module):
  values = None
  latents = None

+ # keep track of termination, for setting the `is_truncated` field on Experience and for early stopping interaction with env
+
+ is_terminated = full((batch,), False, device = device)
+ is_truncated = full((batch,), False, device = device)
+
+ episode_lens = full((batch,), 0, device = device)
+
  # maybe time kv cache

  time_kv_cache = None

- for i in range(max_timesteps + 1):
+ step_index = 0
+
+ while not is_terminated.all():
+ step_index += 1

  latents = self.video_tokenizer(video, return_latents = True)

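The bookkeeping above replaces a fixed-length rollout with a loop that runs until every environment in the batch has terminated, while `episode_lens` only advances for environments that are still live. A self-contained sketch of that control flow, with random termination/truncation signals standing in for the real environment:

import torch
from torch import full

batch, max_timesteps, device = 4, 8, 'cpu'

is_terminated = full((batch,), False, device = device)
is_truncated = full((batch,), False, device = device)
episode_lens = full((batch,), 0, device = device)

step_index = 0

while not is_terminated.all():
    step_index += 1

    # placeholder per-step signals - the real loop gets these from env.step
    terminated = torch.rand(batch, device = device) < 0.1
    truncated = torch.rand(batch, device = device) < 0.05

    # only environments that are still running accumulate length
    episode_lens = torch.where(is_terminated, episode_lens, episode_lens + 1)

    # an env counts as terminated if it says so, or if it was truncated on the previous step
    # (that extra step exists only to produce the bootstrap value)
    is_terminated |= (terminated | is_truncated)

    if step_index <= max_timesteps:
        is_truncated |= truncated

    if step_index == max_timesteps:
        # hitting the step budget truncates anything not already terminated
        is_truncated |= ~is_terminated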
@@ -2187,7 +2202,40 @@ class DynamicsWorldModel(Module):

  # pass the sampled action to the environment and get back next state and reward

- next_frame, reward = env.step((sampled_discrete_actions, sampled_continuous_actions))
+ env_step_out = env.step((sampled_discrete_actions, sampled_continuous_actions))
+
+ if len(env_step_out) == 2:
+ next_frame, reward = env_step_out
+ terminated = full((batch,), False)
+ truncated = full((batch,), False)
+
+ elif len(env_step_out) == 3:
+ next_frame, reward, terminated = env_step_out
+ truncated = full((batch,), False)
+
+ elif len(env_step_out) == 4:
+ next_frame, reward, terminated, truncated = env_step_out
+
+ # update episode lens
+
+ episode_lens = torch.where(is_terminated, episode_lens, episode_lens + 1)
+
+ # update `is_terminated`
+
+ # (1) - environment says it is terminated
+ # (2) - previous step is truncated (this step is for bootstrap value)
+
+ is_terminated |= (terminated | is_truncated)
+
+ # update `is_truncated`
+
+ if step_index <= max_timesteps:
+ is_truncated |= truncated
+
+ if step_index == max_timesteps:
+ # if the step index is at the max time step allowed, set the truncated flag, if not already terminated
+
+ is_truncated |= ~is_terminated

  # batch and time dimension

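The branching on `len(env_step_out)` lets the rollout accept environments that return `(frame, reward)`, `(frame, reward, terminated)`, or the Gymnasium-style `(frame, reward, terminated, truncated)`. The same normalization, sketched as a free function; `unpack_env_step` is a hypothetical name, not something exported by the package:

from torch import full

def unpack_env_step(env_step_out, batch):
    # pad missing termination / truncation signals with all-False tensors
    if len(env_step_out) == 2:
        next_frame, reward = env_step_out
        terminated = full((batch,), False)
        truncated = full((batch,), False)

    elif len(env_step_out) == 3:
        next_frame, reward, terminated = env_step_out
        truncated = full((batch,), False)

    elif len(env_step_out) == 4:
        next_frame, reward, terminated, truncated = env_step_out

    else:
        raise ValueError(f'expected 2 to 4 values from env.step, received {len(env_step_out)}')

    return next_frame, reward, terminated, truncated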
@@ -2216,7 +2264,8 @@ class DynamicsWorldModel(Module):
  values = values,
  step_size = step_size,
  agent_index = agent_index,
- lens = full((batch,), max_timesteps + 1, device = device),
+ is_truncated = is_truncated,
+ lens = episode_lens,
  is_from_world_model = False
  )

@@ -2284,13 +2333,22 @@ class DynamicsWorldModel(Module):

  decay = 1. - self.reward_ema_decay

- # todo - handle distributed
+ # quantile filter
+
+ lo, hi = torch.quantile(returns, self.reward_quantile_filter).tolist()
+ returns_for_stats = returns.clamp(lo, hi)
+
+ # mean, var - todo - handle distributed

  returns_mean, returns_var = returns.mean(), returns.var()

+ # ema
+
  ema_returns_mean.lerp_(returns_mean, decay)
  ema_returns_var.lerp_(returns_var, decay)

+ # normalize
+
  ema_returns_std = ema_returns_var.clamp(min = 1e-5).sqrt()

  normed_returns = (returns - ema_returns_mean) / ema_returns_std
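Taken together, this hunk clamps the returns to the configured quantile band before updating the running mean and variance used to normalize them. A standalone sketch of the idea; note that the clipped returns feed the statistics here, which is one reasonable reading rather than a verbatim copy of the library code:

import torch

def normalize_returns(
    returns,          # (batch,) raw returns
    ema_mean,         # scalar tensor, running mean, updated in place
    ema_var,          # scalar tensor, running variance, updated in place
    quantile_filter = (0.05, 0.95),
    ema_decay = 0.998
):
    decay = 1. - ema_decay

    # clamp outliers to the quantile band before computing statistics
    lo, hi = torch.quantile(returns, torch.tensor(quantile_filter)).tolist()
    returns_for_stats = returns.clamp(lo, hi)

    # exponential moving averages of mean and variance
    ema_mean.lerp_(returns_for_stats.mean(), decay)
    ema_var.lerp_(returns_for_stats.var(), decay)

    # normalize with the running statistics
    ema_std = ema_var.clamp(min = 1e-5).sqrt()
    return (returns - ema_mean) / ema_std

# usage
ema_mean, ema_var = torch.tensor(0.), torch.tensor(1.)
normed = normalize_returns(torch.randn(32) * 10., ema_mean, ema_var)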
@@ -0,0 +1,97 @@
+ from __future__ import annotations
+ from random import choice
+
+ import torch
+ from torch import tensor, empty, randn, randint
+ from torch.nn import Module
+
+ from einops import repeat
+
+ # helpers
+
+ def exists(v):
+ return v is not None
+
+ # mock env
+
+ class MockEnv(Module):
+ def __init__(
+ self,
+ image_shape,
+ reward_range = (-100, 100),
+ num_envs = 1,
+ vectorized = False,
+ terminate_after_step = None,
+ rand_terminate_prob = 0.05,
+ can_truncate = False,
+ rand_truncate_prob = 0.05,
+ ):
+ super().__init__()
+ self.image_shape = image_shape
+ self.reward_range = reward_range
+
+ self.num_envs = num_envs
+ self.vectorized = vectorized
+ assert not (vectorized and num_envs == 1)
+
+ # mocking termination and truncation
+
+ self.can_terminate = exists(terminate_after_step)
+ self.terminate_after_step = terminate_after_step
+ self.rand_terminate_prob = rand_terminate_prob
+
+ self.can_truncate = can_truncate
+ self.rand_truncate_prob = rand_truncate_prob
+
+ self.register_buffer('_step', tensor(0))
+
+ def get_random_state(self):
+ return randn(3, *self.image_shape)
+
+ def reset(
+ self,
+ seed = None
+ ):
+ self._step.zero_()
+ state = self.get_random_state()
+
+ if self.vectorized:
+ state = repeat(state, '... -> b ...', b = self.num_envs)
+
+ return state
+
+ def step(
+ self,
+ actions,
+ ):
+ state = self.get_random_state()
+
+ reward = empty(()).uniform_(*self.reward_range)
+
+ if self.vectorized:
+ discrete, continuous = actions
+ assert discrete.shape[0] == self.num_envs, f'expected batch of actions for {self.num_envs} environments'
+
+ state = repeat(state, '... -> b ...', b = self.num_envs)
+ reward = repeat(reward, ' -> b', b = self.num_envs)
+
+ out = (state, reward)
+
+
+ if self.can_terminate:
+ shape = (self.num_envs,) if self.vectorized else (1,)
+ valid_step = self._step > self.terminate_after_step
+
+ terminate = (torch.rand(shape) < self.rand_terminate_prob) & valid_step
+
+ out = (*out, terminate)
+
+ # maybe truncation
+
+ if self.can_truncate:
+ truncate = (torch.rand(shape) < self.rand_truncate_prob) & valid_step & ~terminate
+ out = (*out, truncate)
+
+ self._step.add_(1)
+
+ return out
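To make the new mock concrete, a short usage sketch under stated assumptions: the discrete and continuous action shapes below are placeholders, since in vectorized mode the mock only asserts on the leading `num_envs` dimension.

import torch
from dreamer4.mocks import MockEnv

env = MockEnv(
    (64, 64),
    vectorized = True,
    num_envs = 4,
    terminate_after_step = 2,   # random termination becomes possible after this many steps
    can_truncate = True
)

state = env.reset()                      # (4, 3, 64, 64)

# placeholder actions - only the leading num_envs dimension is checked
discrete = torch.randint(0, 10, (4, 1))
continuous = torch.randn(4, 1)

# with termination and truncation enabled, step returns a 4-tuple
state, reward, terminate, truncate = env.step((discrete, continuous))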
@@ -1,6 +1,6 @@
  [project]
  name = "dreamer4"
- version = "0.0.81"
+ version = "0.0.83"
  description = "Dreamer 4"
  authors = [
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -612,9 +612,13 @@ def test_cache_generate():

  @param('vectorized', (False, True))
  @param('use_signed_advantage', (False, True))
+ @param('env_can_terminate', (False, True))
+ @param('env_can_truncate', (False, True))
  def test_online_rl(
  vectorized,
- use_signed_advantage
+ use_signed_advantage,
+ env_can_terminate,
+ env_can_truncate
  ):
  from dreamer4.dreamer4 import DynamicsWorldModel, VideoTokenizer

@@ -649,7 +653,14 @@ def test_online_rl(
  from dreamer4.mocks import MockEnv
  from dreamer4.dreamer4 import combine_experiences

- mock_env = MockEnv((256, 256), vectorized = vectorized, num_envs = 4)
+ mock_env = MockEnv(
+ (256, 256),
+ vectorized = vectorized,
+ num_envs = 4,
+ terminate_after_step = 2 if env_can_terminate else None,
+ can_truncate = env_can_truncate,
+ rand_terminate_prob = 0.1
+ )

  # manually

@@ -1,62 +0,0 @@
- from __future__ import annotations
- from random import choice
-
- import torch
- from torch import tensor, empty, randn, randint
- from torch.nn import Module
-
- from einops import repeat
-
- # mock env
-
- class MockEnv(Module):
- def __init__(
- self,
- image_shape,
- reward_range = (-100, 100),
- num_envs = 1,
- vectorized = False
- ):
- super().__init__()
- self.image_shape = image_shape
- self.reward_range = reward_range
-
- self.num_envs = num_envs
- self.vectorized = vectorized
- assert not (vectorized and num_envs == 1)
-
- self.register_buffer('_step', tensor(0))
-
- def get_random_state(self):
- return randn(3, *self.image_shape)
-
- def reset(
- self,
- seed = None
- ):
- self._step.zero_()
- state = self.get_random_state()
-
- if self.vectorized:
- state = repeat(state, '... -> b ...', b = self.num_envs)
-
- return state
-
- def step(
- self,
- actions,
- ):
- state = self.get_random_state()
-
- reward = empty(()).uniform_(*self.reward_range)
-
- if not self.vectorized:
- return state, reward
-
- discrete, continuous = actions
- assert discrete.shape[0] == self.num_envs, f'expected batch of actions for {self.num_envs} environments'
-
- state = repeat(state, '... -> b ...', b = self.num_envs)
- reward = repeat(reward, ' -> b', b = self.num_envs)
-
- return state, reward