dreamer4-0.0.81-py3-none-any.whl → dreamer4-0.0.82-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of dreamer4 might be problematic.

dreamer4/dreamer4.py CHANGED

@@ -1819,6 +1819,7 @@ class DynamicsWorldModel(Module):
         num_residual_streams = 1,
         keep_reward_ema_stats = False,
         reward_ema_decay = 0.998,
+        reward_quantile_filter = (0.05, 0.95),
         gae_discount_factor = 0.997,
         gae_lambda = 0.95,
         ppo_eps_clip = 0.2,

@@ -2029,6 +2030,8 @@ class DynamicsWorldModel(Module):
         self.keep_reward_ema_stats = keep_reward_ema_stats
         self.reward_ema_decay = reward_ema_decay

+        self.register_buffer('reward_quantile_filter', tensor(reward_quantile_filter), persistent = False)
+
         self.register_buffer('ema_returns_mean', tensor(0.))
         self.register_buffer('ema_returns_var', tensor(1.))

@@ -2115,6 +2118,8 @@ class DynamicsWorldModel(Module):
         else:
             video = rearrange(init_frame, 'c vh vw -> 1 c 1 vh vw')

+        batch, device = video.shape[0], video.device
+
         # accumulate

         rewards = None

@@ -2125,6 +2130,11 @@ class DynamicsWorldModel(Module):
         values = None
         latents = None

+        # keep track of termination, for setting the `is_truncated` field on Experience and for early stopping interaction with env
+
+        is_terminated = full((batch,), False, device = device)
+        episode_lens = full((batch,), 0, device = device)
+
         # maybe time kv cache

         time_kv_cache = None

@@ -2187,7 +2197,22 @@ class DynamicsWorldModel(Module):

            # pass the sampled action to the environment and get back next state and reward

-           next_frame, reward = env.step((sampled_discrete_actions, sampled_continuous_actions))
+           env_step_out = env.step((sampled_discrete_actions, sampled_continuous_actions))
+
+           if len(env_step_out) == 2:
+               next_frame, reward = env_step_out
+               terminate = full((batch,), False)
+
+           elif len(env_step_out) == 3:
+               next_frame, reward, terminate = env_step_out
+
+           # update episode lens
+
+           episode_lens = torch.where(is_terminated, episode_lens, episode_lens + 1)
+
+           # update `is_terminated`
+
+           is_terminated |= terminate

            # batch and time dimension
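The hunk above makes the rollout tolerant of two `env.step` return shapes: the original `(next_frame, reward)` pair, or `(next_frame, reward, terminate)` with per-environment termination flags. A minimal sketch of an environment conforming to that contract; `ToyEnv` and its parameters are hypothetical and not part of the package:

import torch

class ToyEnv:
    # hypothetical vectorized environment; only the shape of step()'s return matters here
    def __init__(self, num_envs = 4, frame_shape = (3, 64, 64), emit_terminate = True):
        self.num_envs = num_envs
        self.frame_shape = frame_shape
        self.emit_terminate = emit_terminate

    def step(self, actions):
        discrete_actions, continuous_actions = actions   # same tuple the rollout passes in

        next_frame = torch.randn(self.num_envs, *self.frame_shape)
        reward = torch.randn(self.num_envs)

        if not self.emit_terminate:
            # legacy two-tuple - the rollout then assumes no env terminated this step
            return next_frame, reward

        # three-tuple - boolean termination flags, one per environment
        terminate = torch.rand(self.num_envs) < 0.05
        return next_frame, reward, terminate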
 
@@ -2203,6 +2228,11 @@ class DynamicsWorldModel(Module):
            video = cat((video, next_frame), dim = 2)
            rewards = safe_cat((rewards, reward), dim = 1)

+           # early break out if all terminated
+
+           if is_terminated.all():
+               break
+
        # package up one experience for learning

        batch, device = latents.shape[0], latents.device

@@ -2216,7 +2246,8 @@ class DynamicsWorldModel(Module):
            values = values,
            step_size = step_size,
            agent_index = agent_index,
-           lens = full((batch,), max_timesteps + 1, device = device),
+           is_truncated = ~is_terminated,
+           lens = episode_lens,
            is_from_world_model = False
        )
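The hunks above carry the per-environment termination state into the packaged Experience: lengths stop accumulating once an env terminates, and envs that never terminated are flagged as truncated. A toy illustration of that bookkeeping outside the rollout, with made-up termination steps:

import torch

num_envs, max_timesteps = 4, 8

is_terminated = torch.full((num_envs,), False)
episode_lens = torch.full((num_envs,), 0)

for step in range(max_timesteps):
    # pretend envs 0 and 2 terminate at steps 2 and 4 respectively
    terminate = torch.tensor([step == 2, False, step == 4, False])

    # only environments that are still alive accumulate length
    episode_lens = torch.where(is_terminated, episode_lens, episode_lens + 1)
    is_terminated |= terminate

    if is_terminated.all():
        break

is_truncated = ~is_terminated   # envs that hit the step limit without terminating

print(episode_lens)   # tensor([3, 8, 5, 8])
print(is_truncated)   # tensor([False,  True, False,  True])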
 
@@ -2284,13 +2315,22 @@ class DynamicsWorldModel(Module):

        decay = 1. - self.reward_ema_decay

-       # todo - handle distributed
+       # quantile filter
+
+       lo, hi = torch.quantile(returns, self.reward_quantile_filter).tolist()
+       returns_for_stats = returns.clamp(lo, hi)
+
+       # mean, var - todo - handle distributed

        returns_mean, returns_var = returns.mean(), returns.var()

+       # ema
+
        ema_returns_mean.lerp_(returns_mean, decay)
        ema_returns_var.lerp_(returns_var, decay)

+       # normalize
+
        ema_returns_std = ema_returns_var.clamp(min = 1e-5).sqrt()

        normed_returns = (returns - ema_returns_mean) / ema_returns_std
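The final hunk adds a quantile filter in front of the running return statistics. Below is a standalone sketch of that normalization scheme, assuming the quantile-clamped returns (`returns_for_stats`) are what feed the running mean and variance; the function name and signature are illustrative, not the package's API:

import torch

def normalize_returns(returns, ema_mean, ema_var, quantiles = (0.05, 0.95), ema_decay = 0.998):
    # clamp returns at the given quantiles so a few outlier returns do not dominate the statistics
    lo, hi = torch.quantile(returns, torch.tensor(quantiles)).tolist()
    returns_for_stats = returns.clamp(lo, hi)

    # exponential moving average of the filtered returns' mean and variance
    decay = 1. - ema_decay
    ema_mean.lerp_(returns_for_stats.mean(), decay)
    ema_var.lerp_(returns_for_stats.var(), decay)

    # normalize the unfiltered returns with the running statistics
    ema_std = ema_var.clamp(min = 1e-5).sqrt()
    return (returns - ema_mean) / ema_std

returns = torch.randn(64) * 10
ema_mean, ema_var = torch.tensor(0.), torch.tensor(1.)
normed = normalize_returns(returns, ema_mean, ema_var)

Clamping to, say, the 5th and 95th percentiles keeps a handful of extreme returns from inflating the EMA variance and flattening the normalized returns.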
dreamer4/mocks.py CHANGED

@@ -7,6 +7,11 @@ from torch.nn import Module

 from einops import repeat

+# helpers
+
+def exists(v):
+    return v is not None
+
 # mock env

 class MockEnv(Module):

@@ -15,7 +20,9 @@ class MockEnv(Module):
        image_shape,
        reward_range = (-100, 100),
        num_envs = 1,
-       vectorized = False
+       vectorized = False,
+       terminate_after_step = None,
+       rand_terminate_prob = 0.05
    ):
        super().__init__()
        self.image_shape = image_shape

@@ -25,6 +32,12 @@ class MockEnv(Module):
        self.vectorized = vectorized
        assert not (vectorized and num_envs == 1)

+       # mocking termination
+
+       self.can_terminate = exists(terminate_after_step)
+       self.terminate_after_step = terminate_after_step
+       self.rand_terminate_prob = rand_terminate_prob
+
        self.register_buffer('_step', tensor(0))

    def get_random_state(self):

@@ -50,13 +63,23 @@ class MockEnv(Module):

        reward = empty(()).uniform_(*self.reward_range)

-       if not self.vectorized:
-           return state, reward
+       if self.vectorized:
+           discrete, continuous = actions
+           assert discrete.shape[0] == self.num_envs, f'expected batch of actions for {self.num_envs} environments'
+
+           state = repeat(state, '... -> b ...', b = self.num_envs)
+           reward = repeat(reward, ' -> b', b = self.num_envs)
+
+       out = (state, reward)
+
+       if self.can_terminate:
+           terminate = (
+               (torch.rand((self.num_envs)) < self.rand_terminate_prob) &
+               (self._step > self.terminate_after_step)
+           )

-       discrete, continuous = actions
-       assert discrete.shape[0] == self.num_envs, f'expected batch of actions for {self.num_envs} environments'
+           out = (*out, terminate)

-       state = repeat(state, '... -> b ...', b = self.num_envs)
-       reward = repeat(reward, ' -> b', b = self.num_envs)
+       self._step.add_(1)

-       return state, reward
+       return out
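For reference, a short usage sketch of the termination mocking added above; the keyword values and action shapes are arbitrary placeholders:

import torch
from dreamer4.mocks import MockEnv

# vectorized mock env that may emit random terminations once past step 10
env = MockEnv(
    image_shape = (3, 64, 64),
    num_envs = 4,
    vectorized = True,
    terminate_after_step = 10,
    rand_terminate_prob = 0.05
)

# action shapes are placeholders - the mock only checks the leading (num_envs) dimension
discrete_actions = torch.randint(0, 2, (4, 1))
continuous_actions = torch.randn(4, 1)

state, reward, terminate = env.step((discrete_actions, continuous_actions))
print(state.shape[0], reward.shape, terminate.shape)   # 4, torch.Size([4]), torch.Size([4])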
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dreamer4
-Version: 0.0.81
+Version: 0.0.82
 Summary: Dreamer 4
 Project-URL: Homepage, https://pypi.org/project/dreamer4/
 Project-URL: Repository, https://github.com/lucidrains/dreamer4
@@ -0,0 +1,8 @@
+dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
+dreamer4/dreamer4.py,sha256=1WeHCCJO-wa85Ki9uP8FR70pFTro46LF4WyFMxSD90I,110556
+dreamer4/mocks.py,sha256=S1kiENV3kHM1L6pBOLDqLVCZJg7ZydEmagiNb8sFIXc,2077
+dreamer4/trainers.py,sha256=D2b7WTgTHElLhIWLFgl2Ct2knGJLTk91HHpC5UkNvG0,14028
+dreamer4-0.0.82.dist-info/METADATA,sha256=by1KbrcKVy0epdLDLAiSN1MXR_FBgyD40Ol_Mn5iZNM,3065
+dreamer4-0.0.82.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+dreamer4-0.0.82.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+dreamer4-0.0.82.dist-info/RECORD,,
@@ -1,8 +0,0 @@
-dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
-dreamer4/dreamer4.py,sha256=8IxpSzfpm_PLqwJ1YReVRuLuvbvQJNrhIKCvqK35jdQ,109317
-dreamer4/mocks.py,sha256=Oi91Yv1oK0E-Wz-KDkf79xoyWzIXCvMLCr0WYCpJDLA,1482
-dreamer4/trainers.py,sha256=D2b7WTgTHElLhIWLFgl2Ct2knGJLTk91HHpC5UkNvG0,14028
-dreamer4-0.0.81.dist-info/METADATA,sha256=8qDj5Pw3N_hqXOrGFCnP6bgbTqB7B2dn5_9FRn4mExw,3065
-dreamer4-0.0.81.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-dreamer4-0.0.81.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-dreamer4-0.0.81.dist-info/RECORD,,