dreamer4 0.0.80-py3-none-any.whl → 0.0.82-py3-none-any.whl

This diff shows the content of two publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.

Potentially problematic release: this version of dreamer4 might be problematic.

dreamer4/dreamer4.py CHANGED
@@ -1817,6 +1817,9 @@ class DynamicsWorldModel(Module):
  continuous_action_loss_weight: float | list[float] = 1.,
  num_latent_genes = 0, # for carrying out evolution within the dreams https://web3.arxiv.org/abs/2503.19037
  num_residual_streams = 1,
+ keep_reward_ema_stats = False,
+ reward_ema_decay = 0.998,
+ reward_quantile_filter = (0.05, 0.95),
  gae_discount_factor = 0.997,
  gae_lambda = 0.95,
  ppo_eps_clip = 0.2,
@@ -2022,6 +2025,16 @@ class DynamicsWorldModel(Module):
  self.value_clip = value_clip
  self.policy_entropy_weight = value_clip

+ # rewards related
+
+ self.keep_reward_ema_stats = keep_reward_ema_stats
+ self.reward_ema_decay = reward_ema_decay
+
+ self.register_buffer('reward_quantile_filter', tensor(reward_quantile_filter), persistent = False)
+
+ self.register_buffer('ema_returns_mean', tensor(0.))
+ self.register_buffer('ema_returns_var', tensor(1.))
+
  # loss related

  self.flow_loss_normalizer = LossNormalizer(1)
@@ -2105,6 +2118,8 @@ class DynamicsWorldModel(Module):
  else:
      video = rearrange(init_frame, 'c vh vw -> 1 c 1 vh vw')

+ batch, device = video.shape[0], video.device
+
  # accumulate

  rewards = None
@@ -2115,6 +2130,11 @@ class DynamicsWorldModel(Module):
  values = None
  latents = None

+ # keep track of termination, for setting the `is_truncated` field on Experience and for early stopping interaction with env
+
+ is_terminated = full((batch,), False, device = device)
+ episode_lens = full((batch,), 0, device = device)
+
  # maybe time kv cache

  time_kv_cache = None
@@ -2177,7 +2197,22 @@ class DynamicsWorldModel(Module):

  # pass the sampled action to the environment and get back next state and reward

- next_frame, reward = env.step((sampled_discrete_actions, sampled_continuous_actions))
+ env_step_out = env.step((sampled_discrete_actions, sampled_continuous_actions))
+
+ if len(env_step_out) == 2:
+     next_frame, reward = env_step_out
+     terminate = full((batch,), False)
+
+ elif len(env_step_out) == 3:
+     next_frame, reward, terminate = env_step_out
+
+ # update episode lens
+
+ episode_lens = torch.where(is_terminated, episode_lens, episode_lens + 1)
+
+ # update `is_terminated`
+
+ is_terminated |= terminate

  # batch and time dimension

@@ -2193,6 +2228,11 @@ class DynamicsWorldModel(Module):
  video = cat((video, next_frame), dim = 2)
  rewards = safe_cat((rewards, reward), dim = 1)

+ # early break out if all terminated
+
+ if is_terminated.all():
+     break
+
  # package up one experience for learning

  batch, device = latents.shape[0], latents.device
@@ -2206,7 +2246,8 @@ class DynamicsWorldModel(Module):
  values = values,
  step_size = step_size,
  agent_index = agent_index,
- lens = full((batch,), max_timesteps + 1, device = device),
+ is_truncated = ~is_terminated,
+ lens = episode_lens,
  is_from_world_model = False
  )

@@ -2267,11 +2308,41 @@ class DynamicsWorldModel(Module):

  world_model_forward_context = torch.no_grad if only_learn_policy_value_heads else nullcontext

+ # maybe keep track returns statistics and normalize returns and values before calculating advantage, as done in dreamer v3
+
+ if self.keep_reward_ema_stats:
+     ema_returns_mean, ema_returns_var = self.ema_returns_mean, self.ema_returns_var
+
+     decay = 1. - self.reward_ema_decay
+
+     # quantile filter
+
+     lo, hi = torch.quantile(returns, self.reward_quantile_filter).tolist()
+     returns_for_stats = returns.clamp(lo, hi)
+
+     # mean, var - todo - handle distributed
+
+     returns_mean, returns_var = returns.mean(), returns.var()
+
+     # ema
+
+     ema_returns_mean.lerp_(returns_mean, decay)
+     ema_returns_var.lerp_(returns_var, decay)
+
+     # normalize
+
+     ema_returns_std = ema_returns_var.clamp(min = 1e-5).sqrt()
+
+     normed_returns = (returns - ema_returns_mean) / ema_returns_std
+     normed_old_values = (old_values - ema_returns_mean) / ema_returns_std
+
+     advantage = normed_returns - normed_old_values
+ else:
+     advantage = returns - old_values
+
  # apparently they just use the sign of the advantage
  # https://arxiv.org/abs/2410.04166v1

- advantage = returns - old_values
-
  if use_signed_advantage:
      advantage = advantage.sign()
  else:
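
The new `keep_reward_ema_stats` path follows the return-normalization trick from Dreamer V3: clamp a batch of returns to a quantile range, keep an exponential moving average of their mean and variance, and divide returns and value estimates by the EMA standard deviation before forming the advantage. Below is a minimal standalone sketch of that idea in PyTorch; `ReturnNormalizer` and its methods are illustrative names only, not part of the dreamer4 API, and in this sketch the quantile-clamped returns are the ones fed into the running statistics.

    # standalone sketch of DreamerV3-style return normalization (illustrative only)

    import torch
    from torch import nn, tensor

    class ReturnNormalizer(nn.Module):  # hypothetical helper, not the package API
        def __init__(self, decay = 0.998, quantile_filter = (0.05, 0.95)):
            super().__init__()
            self.decay = decay
            self.register_buffer('quantile_filter', tensor(quantile_filter))
            self.register_buffer('ema_mean', tensor(0.))
            self.register_buffer('ema_var', tensor(1.))

        @torch.no_grad()
        def update(self, returns):
            # clamp returns to the given quantile range before computing statistics
            lo, hi = torch.quantile(returns, self.quantile_filter).tolist()
            filtered = returns.clamp(lo, hi)

            # exponential moving average of mean and variance
            weight = 1. - self.decay
            self.ema_mean.lerp_(filtered.mean(), weight)
            self.ema_var.lerp_(filtered.var(), weight)

        def normalize(self, values):
            std = self.ema_var.clamp(min = 1e-5).sqrt()
            return (values - self.ema_mean) / std

    # usage: normalize returns and old values with shared statistics, then form the advantage

    normalizer = ReturnNormalizer()
    returns, old_values = torch.randn(4, 16) * 10., torch.randn(4, 16) * 10.

    normalizer.update(returns)
    advantage = normalizer.normalize(returns) - normalizer.normalize(old_values)

The point of the EMA statistics is to keep the advantage on a roughly constant scale even when reward magnitudes drift over the course of training.
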
dreamer4/mocks.py CHANGED
@@ -7,6 +7,11 @@ from torch.nn import Module

  from einops import repeat

+ # helpers
+
+ def exists(v):
+     return v is not None
+
  # mock env

  class MockEnv(Module):
@@ -15,7 +20,9 @@ class MockEnv(Module):
  image_shape,
  reward_range = (-100, 100),
  num_envs = 1,
- vectorized = False
+ vectorized = False,
+ terminate_after_step = None,
+ rand_terminate_prob = 0.05
  ):
  super().__init__()
  self.image_shape = image_shape
@@ -25,6 +32,12 @@ class MockEnv(Module):
  self.vectorized = vectorized
  assert not (vectorized and num_envs == 1)

+ # mocking termination
+
+ self.can_terminate = exists(terminate_after_step)
+ self.terminate_after_step = terminate_after_step
+ self.rand_terminate_prob = rand_terminate_prob
+
  self.register_buffer('_step', tensor(0))

  def get_random_state(self):
@@ -50,13 +63,23 @@

  reward = empty(()).uniform_(*self.reward_range)

- if not self.vectorized:
-     return state, reward
+ if self.vectorized:
+     discrete, continuous = actions
+     assert discrete.shape[0] == self.num_envs, f'expected batch of actions for {self.num_envs} environments'
+
+     state = repeat(state, '... -> b ...', b = self.num_envs)
+     reward = repeat(reward, ' -> b', b = self.num_envs)
+
+ out = (state, reward)
+
+ if self.can_terminate:
+     terminate = (
+         (torch.rand((self.num_envs)) < self.rand_terminate_prob) &
+         (self._step > self.terminate_after_step)
+     )

- discrete, continuous = actions
- assert discrete.shape[0] == self.num_envs, f'expected batch of actions for {self.num_envs} environments'
+     out = (*out, terminate)

- state = repeat(state, '... -> b ...', b = self.num_envs)
- reward = repeat(reward, ' -> b', b = self.num_envs)
+ self._step.add_(1)

- return state, reward
+ return out
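
Together with the rollout changes above, the mock environment now exercises the optional third return value from `env.step`: `(state, reward)` for environments that never terminate, `(state, reward, terminate)` for ones that do. Below is a standalone sketch of a consumer of that contract; `ToyEnv` and every other name in it are hypothetical, and the loop only illustrates how per-environment episode lengths can be tracked and how interaction can stop once every environment has terminated.

    # standalone sketch of handling a 2-tuple or 3-tuple env.step return (illustrative only)

    import torch
    from torch import full

    class ToyEnv:
        # hypothetical environment that returns a random frame, reward, and termination flag per env
        def __init__(self, num_envs = 4, terminate_prob = 0.05):
            self.num_envs = num_envs
            self.terminate_prob = terminate_prob

        def step(self, actions):
            state = torch.randn(self.num_envs, 3, 64, 64)
            reward = torch.randn(self.num_envs)
            terminate = torch.rand(self.num_envs) < self.terminate_prob
            return state, reward, terminate  # drop `terminate` to emulate a 2-tuple env

    env = ToyEnv()
    batch, max_timesteps = env.num_envs, 100

    is_terminated = full((batch,), False)
    episode_lens = full((batch,), 0)

    for _ in range(max_timesteps):
        env_step_out = env.step(None)

        if len(env_step_out) == 2:
            # env does not signal termination - treat every step as non-terminal
            state, reward = env_step_out
            terminate = full((batch,), False)
        else:
            state, reward, terminate = env_step_out

        # only environments that are still running accumulate episode length
        episode_lens = torch.where(is_terminated, episode_lens, episode_lens + 1)
        is_terminated |= terminate

        # stop interacting once every environment has terminated
        if is_terminated.all():
            break

    print(episode_lens.tolist(), is_terminated.tolist())
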
dreamer4-0.0.82.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: dreamer4
- Version: 0.0.80
+ Version: 0.0.82
  Summary: Dreamer 4
  Project-URL: Homepage, https://pypi.org/project/dreamer4/
  Project-URL: Repository, https://github.com/lucidrains/dreamer4

dreamer4-0.0.82.dist-info/RECORD ADDED
@@ -0,0 +1,8 @@
+ dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
+ dreamer4/dreamer4.py,sha256=1WeHCCJO-wa85Ki9uP8FR70pFTro46LF4WyFMxSD90I,110556
+ dreamer4/mocks.py,sha256=S1kiENV3kHM1L6pBOLDqLVCZJg7ZydEmagiNb8sFIXc,2077
+ dreamer4/trainers.py,sha256=D2b7WTgTHElLhIWLFgl2Ct2knGJLTk91HHpC5UkNvG0,14028
+ dreamer4-0.0.82.dist-info/METADATA,sha256=by1KbrcKVy0epdLDLAiSN1MXR_FBgyD40Ol_Mn5iZNM,3065
+ dreamer4-0.0.82.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ dreamer4-0.0.82.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+ dreamer4-0.0.82.dist-info/RECORD,,

dreamer4-0.0.80.dist-info/RECORD DELETED
@@ -1,8 +0,0 @@
- dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
- dreamer4/dreamer4.py,sha256=65MN3_ZqB3w5oXiWvpjRIHxxTwsACmnpw1HTd0GVJkU,108143
- dreamer4/mocks.py,sha256=Oi91Yv1oK0E-Wz-KDkf79xoyWzIXCvMLCr0WYCpJDLA,1482
- dreamer4/trainers.py,sha256=D2b7WTgTHElLhIWLFgl2Ct2knGJLTk91HHpC5UkNvG0,14028
- dreamer4-0.0.80.dist-info/METADATA,sha256=kyaYELFKTxlaWrjMYHGFiY1LtuOKl1_fBC1GSd9eo2A,3065
- dreamer4-0.0.80.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- dreamer4-0.0.80.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
- dreamer4-0.0.80.dist-info/RECORD,,